Unverified Commit 78f13981 authored by Liangsheng Yin, committed by GitHub

[1/N] DP-Refactor: move communicators into `tokenizer_communicator_mixin` (#10028)
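This change extracts the scheduler-communication plumbing (the per-request-type `_Communicator` objects and their `handle_recv` dispatch entries) out of `tokenizer_manager.py` into a new `TokenizerCommunicatorMixin`, which `TokenizerManager` now inherits; the mixin merges its handlers into the manager's result dispatcher through the new `TypeBasedDispatcher.__iadd__`. The toy sketch below is illustrative only (class and attribute names are stand-ins where they differ from this diff) and shows the pattern being adopted:

from typing import Any, Callable, List, Tuple, Type

class TypeBasedDispatcher:
    """A minimal dispatcher in the spirit of sglang's TypeBasedDispatcher."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __iadd__(self, other: "TypeBasedDispatcher"):
        # The method added to the real class in this PR: merge another dispatcher's entries.
        self._mapping.extend(other._mapping)
        return self

    def __call__(self, obj: Any):
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise TypeError(f"no handler registered for {type(obj).__name__}")  # toy fallback

class CommunicatorMixin:
    """Stand-in for TokenizerCommunicatorMixin: contributes handlers via an init hook."""

    def init_communicators(self) -> None:
        self._result_dispatcher += TypeBasedDispatcher(
            [(str, lambda s: f"mixin handled {s!r}")]
        )

class Manager(CommunicatorMixin):
    """Stand-in for TokenizerManager: builds its own dispatcher, then calls the mixin hook."""

    def __init__(self) -> None:
        self._result_dispatcher = TypeBasedDispatcher([(int, lambda i: i + 1)])
        self.init_communicators()

m = Manager()
print(m._result_dispatcher(41))      # -> 42 (manager's own entry)
print(m._result_dispatcher("ping"))  # -> "mixin handled 'ping'" (entry added by the mixin)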

parent bfd7a18d
@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
MultiTokenizerRegisterReq,
MultiTokenizerWrapper,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
......
from __future__ import annotations
import asyncio
import logging
import os
import time
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Deque,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
import fastapi
from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReqInput,
FlushCacheReqOutput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWrapper,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.server_args import LoRARef, ServerArgs
from sglang.srt.utils import get_bool_env_var
from sglang.utils import TypeBasedDispatcher
if TYPE_CHECKING:
from sglang.srt.managers.tokenizer_manager import TokenizerManager
T = TypeVar("T")
logger = logging.getLogger(__name__)
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_event: Optional[asyncio.Event] = None
self._result_values: Optional[List[T]] = None
self._ready_queue: Deque[asyncio.Event] = deque()
async def __call__(self, obj):
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
await ready_event.wait()
assert self._result_event is None
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
self._result_values = []
await self._result_event.wait()
result_values = self._result_values
self._result_event = self._result_values = None
if len(self._ready_queue) > 0:
self._ready_queue.popleft().set()
return result_values
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_event.set()
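# Usage sketch (illustrative only; `sender` stands in for the ZMQ socket, and
# SomeReqInput / the reply objects are hypothetical):
#
#   comm = _Communicator(sender, fan_out=dp_size)      # expect one reply per DP rank
#   task = asyncio.create_task(comm(SomeReqInput()))    # sends the request, then waits
#   comm.handle_recv(reply_from_rank_0)                 # invoked via the result dispatcher
#   comm.handle_recv(reply_from_rank_1)                 # fan_out reached -> wake the caller
#   replies = await task                                # list with one entry per rank
#
# Only one request is in flight at a time; concurrent callers wait in _ready_queue.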
class TokenizerCommunicatorMixin:
"""Mixin class for TokenizerManager to handle communication with the scheduler."""
def init_communicators(self: TokenizerManager, server_args: ServerArgs):
# Communicators
self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_distributed_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.release_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.clear_hicache_storage_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher += self._get_communicator_dispatcher()
def _get_communicator_dispatcher(self: TokenizerManager):
return TypeBasedDispatcher(
[
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
),
(
UpdateWeightsFromDistributedReqOutput,
self.update_weights_from_distributed_communicator.handle_recv,
),
(
UpdateWeightsFromTensorReqOutput,
self.update_weights_from_tensor_communicator.handle_recv,
),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
),
(
ReleaseMemoryOccupationReqOutput,
self.release_memory_occupation_communicator.handle_recv,
),
(
ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv,
),
(
SlowDownReqOutput,
self.slow_down_communicator.handle_recv,
),
(
ClearHiCacheReqOutput,
self.clear_hicache_storage_communicator.handle_recv,
),
(
FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv,
),
(
ProfileReqOutput,
self.profile_communicator.handle_recv,
),
(
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
]
)
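# Routing example (illustrative): when the scheduler replies with a
# FlushCacheReqOutput, the dispatcher above calls
# flush_cache_communicator.handle_recv(), which collects one reply per DP rank
# and then wakes the pending `await self.flush_cache_communicator(...)` call
# inside flush_cache() below.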
async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
"""Clear the hierarchical cache storage."""
# Delegate to the scheduler to handle HiCacheStorage clearing
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
0
]
async def start_profile(
self: TokenizerManager,
output_dir: Optional[str] = None,
start_step: Optional[int] = None,
num_steps: Optional[int] = None,
activities: Optional[List[str]] = None,
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
):
self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
with_stack = False if with_stack is False or env_with_stack is False else True
req = ProfileReq(
type=ProfileReqType.START_PROFILE,
output_dir=output_dir,
start_step=start_step,
num_steps=num_steps,
activities=activities,
with_stack=with_stack,
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
)
return await self._execute_profile(req)
async def stop_profile(self: TokenizerManager):
self.auto_create_handle_loop()
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
return await self._execute_profile(req)
async def _execute_profile(self: TokenizerManager, req: ProfileReq):
result = (await self.profile_communicator(req))[0]
if not result.success:
raise RuntimeError(result.message)
return result
async def start_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
async def stop_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
async def dump_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def init_weights_update_group(
self: TokenizerManager,
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
result = (await self.init_weights_update_group_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_distributed(
self: TokenizerManager,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_tensor(
self: TokenizerManager,
obj: UpdateWeightsFromTensorReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def load_lora_adapter(
self: TokenizerManager,
obj: LoadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start load Lora adapter. Lora name=%s, path=%s",
obj.lora_name,
obj.lora_path,
)
async with self.lora_update_lock:
if (
self.server_args.max_loaded_loras is not None
and self.lora_registry.num_registered_loras
>= self.server_args.max_loaded_loras
):
raise ValueError(
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
"Please unload some LoRA adapters before loading new ones."
)
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
pinned=obj.pinned,
)
# Trigger the actual loading operation at the backend processes.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
return result
except ValueError as e:
return LoadLoRAAdapterReqOutput(
success=False,
error_message=str(e),
)
async def unload_lora_adapter(
self: TokenizerManager,
obj: UnloadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start unload Lora adapter. Lora name=%s",
obj.lora_name,
)
async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name(
self: TokenizerManager,
obj: GetWeightsByNameReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
results = await self.get_weights_by_name_communicator(obj)
all_parameters = [r.parameter for r in results]
if self.server_args.dp_size == 1:
return all_parameters[0]
else:
return all_parameters
async def release_memory_occupation(
self: TokenizerManager,
obj: ReleaseMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.release_memory_occupation_communicator(obj)
async def resume_memory_occupation(
self: TokenizerManager,
obj: ResumeMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj)
async def slow_down(
self: TokenizerManager,
obj: SlowDownReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.slow_down_communicator(obj)
async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
req = GetInternalStateReq()
responses: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
# One response per DP rank
return [res.internal_state for res in responses]
async def set_internal_state(
self: TokenizerManager, obj: SetInternalStateReq
) -> List[bool]:
responses: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return [res.updated for res in responses]
async def get_load(self: TokenizerManager) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
@@ -31,19 +31,7 @@ from contextlib import nullcontext
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import (
Any,
Awaitable,
Deque,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
import fastapi
import torch
@@ -70,57 +58,26 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReqInput,
FlushCacheReqOutput,
FreezeGCReq,
GenerateReqInput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWrapper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -177,7 +134,7 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
class TokenizerManager:
class TokenizerManager(TokenizerCommunicatorMixin):
"""TokenizerManager is a process that tokenizes the text."""
def __init__(
@@ -343,50 +300,6 @@ class TokenizerManager:
if self.server_args.gc_warning_threshold_secs > 0.0:
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
# Communicators
self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_distributed_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.release_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.clear_hicache_storage_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
(
@@ -404,70 +317,16 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
),
(
UpdateWeightsFromDistributedReqOutput,
self.update_weights_from_distributed_communicator.handle_recv,
),
(
UpdateWeightsFromTensorReqOutput,
self.update_weights_from_tensor_communicator.handle_recv,
),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
),
(
ReleaseMemoryOccupationReqOutput,
self.release_memory_occupation_communicator.handle_recv,
),
(
ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv,
),
(
SlowDownReqOutput,
self.slow_down_communicator.handle_recv,
),
(
ClearHiCacheReqOutput,
self.clear_hicache_storage_communicator.handle_recv,
),
(
FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv,
),
(
ProfileReqOutput,
self.profile_communicator.handle_recv,
),
(
FreezeGCReq,
lambda x: None,
),  # When the scheduler skips the detokenizer and replies directly to the tokenizer manager, we ignore the message.
(
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
self.init_communicators(server_args)
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -983,16 +842,6 @@ class TokenizerManager:
except StopAsyncIteration:
pass
async def flush_cache(self) -> FlushCacheReqOutput:
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
"""Clear the hierarchical cache storage."""
# Delegate to the scheduler to handle HiCacheStorage clearing
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
0
]
def abort_request(self, rid: str = "", abort_all: bool = False):
if not abort_all and rid not in self.rid_to_state:
return
@@ -1002,55 +851,6 @@ class TokenizerManager:
if self.enable_metrics:
self.metrics_collector.observe_one_aborted_request()
async def start_profile(
self,
output_dir: Optional[str] = None,
start_step: Optional[int] = None,
num_steps: Optional[int] = None,
activities: Optional[List[str]] = None,
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
):
self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
with_stack = False if with_stack is False or env_with_stack is False else True
req = ProfileReq(
type=ProfileReqType.START_PROFILE,
output_dir=output_dir,
start_step=start_step,
num_steps=num_steps,
activities=activities,
with_stack=with_stack,
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
)
return await self._execute_profile(req)
async def stop_profile(self):
self.auto_create_handle_loop()
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
return await self._execute_profile(req)
async def _execute_profile(self, req: ProfileReq):
result = (await self.profile_communicator(req))[0]
if not result.success:
raise RuntimeError(result.message)
return result
async def start_expert_distribution_record(self):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
async def stop_expert_distribution_record(self):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
async def dump_expert_distribution_record(self):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
async with self.is_pause_cond:
self.is_pause = True
@@ -1111,191 +911,6 @@ class TokenizerManager:
all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests
async def init_weights_update_group(
self,
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
result = (await self.init_weights_update_group_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_distributed(
self,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_tensor(
self,
obj: UpdateWeightsFromTensorReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def load_lora_adapter(
self,
obj: LoadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start load Lora adapter. Lora name=%s, path=%s",
obj.lora_name,
obj.lora_path,
)
async with self.lora_update_lock:
if (
self.server_args.max_loaded_loras is not None
and self.lora_registry.num_registered_loras
>= self.server_args.max_loaded_loras
):
raise ValueError(
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
"Please unload some LoRA adapters before loading new ones."
)
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
pinned=obj.pinned,
)
# Trigger the actual loading operation at the backend processes.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
return result
except ValueError as e:
return LoadLoRAAdapterReqOutput(
success=False,
error_message=str(e),
)
async def unload_lora_adapter(
self,
obj: UnloadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start unload Lora adapter. Lora name=%s",
obj.lora_name,
)
async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
self.auto_create_handle_loop()
results = await self.get_weights_by_name_communicator(obj)
all_parameters = [r.parameter for r in results]
if self.server_args.dp_size == 1:
return all_parameters[0]
else:
return all_parameters
async def release_memory_occupation(
self,
obj: ReleaseMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.release_memory_occupation_communicator(obj)
async def resume_memory_occupation(
self,
obj: ResumeMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj)
async def slow_down(
self,
obj: SlowDownReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.slow_down_communicator(obj)
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
@@ -1320,28 +935,6 @@ class TokenizerManager:
):
await self.send_to_scheduler.send_pyobj(obj)
async def get_internal_state(self) -> List[Dict[Any, Any]]:
req = GetInternalStateReq()
responses: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
# Many DP ranks
return [res.internal_state for res in responses]
async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
responses: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return [res.updated for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
def get_log_request_metadata(self):
max_length = None
skip_names = None
@@ -2108,51 +1701,6 @@ class SignalHandler:
kill_process_tree(os.getpid())
T = TypeVar("T")
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_event: Optional[asyncio.Event] = None
self._result_values: Optional[List[T]] = None
self._ready_queue: Deque[asyncio.Future] = deque()
async def __call__(self, obj):
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
await ready_event.wait()
assert self._result_event is None
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
self._result_values = []
await self._result_event.wait()
result_values = self._result_values
self._result_event = self._result_values = None
if len(self._ready_queue) > 0:
self._ready_queue.popleft().set()
return result_values
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_event.set()
# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
......
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer
......
@@ -473,6 +473,10 @@ class TypeBasedDispatcher:
def __init__(self, mapping: List[Tuple[Type, Callable]]):
self._mapping = mapping
def __iadd__(self, other: "TypeBasedDispatcher"):
self._mapping.extend(other._mapping)
return self
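# `__iadd__` lets another component extend an existing dispatcher in place; the new
# TokenizerCommunicatorMixin uses it as:
#   self._result_dispatcher += self._get_communicator_dispatcher()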
def __call__(self, obj: Any):
for ty, fn in self._mapping:
if isinstance(obj, ty):
......