Unverified Commit 305c9e8c authored by Liangsheng Yin, committed by GitHub

[4/N]DP refactor: support watching mode `get_load` and shortest queue strategy (#10201)

parent ca63f075
@@ -27,7 +27,7 @@ import tempfile
 import threading
 import time
 from http import HTTPStatus
-from typing import Any, AsyncIterator, Callable, Dict, List, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

 import setproctitle
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
+    MultiTokenizerRouter,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -127,7 +128,9 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: TokenizerManager
+    tokenizer_manager: Union[
+        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
+    ]
     template_manager: TemplateManager
     scheduler_info: Dict
...
@@ -21,6 +21,7 @@ import struct
 import sys
 import threading
 import time
+from collections import deque
 from enum import Enum, auto
 from multiprocessing import shared_memory
 from typing import Dict, List
@@ -34,6 +35,7 @@ from sglang.srt.managers.io_struct import (
     BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -46,7 +48,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -67,6 +69,42 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc
+
+class DPBudget:
+    def __init__(self):
+        # TODO: support minimum tokens method
+        self.budget_queue = deque()
+
+    def update_budget(self, load_update: WatchLoadUpdateReq):
+        """Update the budget queue.
+
+        Use num_reqs instead of num_waiting_reqs to balance decode running batch.
+        """
+        loads = load_update.loads
+        self.budget_queue.clear()
+
+        num_reqs = [load.num_reqs for load in loads]
+        if not num_reqs:
+            return
+
+        max_num_reqs = max(num_reqs)
+        if all(x == max_num_reqs for x in num_reqs):
+            return
+
+        while any(x != num_reqs[0] for x in num_reqs):
+            min_load = min(num_reqs)
+            min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
+            second_min_load = min(x for x in num_reqs if x > min_load)
+            self.budget_queue.extend(
+                [loads[i].dp_rank for i in min_indices] * (second_min_load - min_load)
+            )
+            for idx in min_indices:
+                num_reqs[idx] = second_min_load
+
+    def dispatch(self):
+        if self.budget_queue:
+            return self.budget_queue.popleft()
+        return None
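
> To make the budget construction concrete, here is a small, self-contained sketch of the same water-filling idea. `Load` and `build_budget` are stand-ins for `GetLoadReqOutput` and `DPBudget.update_budget` above, redefined so the snippet runs outside SGLang: the least-loaded ranks are repeatedly topped up to the next-lowest level, and every top-up emits that rank into the budget queue.

```python
# Illustrative sketch only; not the SGLang implementation itself.
from collections import deque
from dataclasses import dataclass
from typing import List


@dataclass
class Load:
    dp_rank: int
    num_reqs: int


def build_budget(loads: List[Load]) -> deque:
    budget = deque()
    num_reqs = [load.num_reqs for load in loads]
    if not num_reqs or all(x == max(num_reqs) for x in num_reqs):
        return budget
    while any(x != num_reqs[0] for x in num_reqs):
        min_load = min(num_reqs)
        min_indices = [i for i, x in enumerate(num_reqs) if x == min_load]
        second_min = min(x for x in num_reqs if x > min_load)
        # Every request needed to raise a minimum-load rank to the next level
        # becomes one "budget token" naming that rank.
        budget.extend([loads[i].dp_rank for i in min_indices] * (second_min - min_load))
        for i in min_indices:
            num_reqs[i] = second_min
    return budget


# Ranks 0, 1, 2 currently hold 5, 2, 4 running+waiting requests.
print(list(build_budget([Load(0, 5), Load(1, 2), Load(2, 4)])))
# -> [1, 1, 1, 2]: the next three requests go to rank 1, the fourth to rank 2,
#    after which the queue is empty and dispatch falls back to round robin.
```

> Because the queue only ever levels ranks up to the current maximum, a balanced cluster (all counts equal) produces no budget at all and the controller behaves exactly like round robin.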

 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

@@ -104,6 +142,9 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

+        # Load balance budget
+        self.dp_budget = DPBudget()
+
         # Launch data parallel workers
         self.scheduler_procs = []
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
@@ -127,6 +168,31 @@ class DataParallelController:
         self.max_req_input_len = None

+        self.init_dispatcher()
+
+    def send_to_all_workers(self, obj):
+        for worker in self.workers:
+            worker.send_pyobj(obj)
+
+    def send_control_message(self, obj):
+        # Send control messages to first worker of tp group
+        for worker in self.workers[:: self.control_message_step]:
+            worker.send_pyobj(obj)
+
+    def handle_load_update_req(self, obj):
+        self.dp_budget.update_budget(obj)
+
+    def init_dispatcher(self):
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.dispatching),
+                (TokenizedEmbeddingReqInput, self.dispatching),
+                (BlockReqInput, self.send_to_all_workers),
+                (WatchLoadUpdateReq, self.handle_load_update_req),
+            ]
+        )
+        self._request_dispatcher.add_fallback_fn(self.send_control_message)
+
     def launch_dp_schedulers(self, server_args, port_args):
         base_gpu_id = 0
@@ -291,10 +357,14 @@ class DataParallelController:
         else:
             self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

-    def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
             return
-        raise NotImplementedError()
+        target_worker = self.dp_budget.dispatch()
+        if target_worker is None:
+            self.round_robin_scheduler(req)
+        else:
+            self.workers[target_worker].send_pyobj(req)

     def minimum_tokens_scheduler(self, req):
         if self.maybe_external_dp_rank_routing(req):
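
> The dispatch path above is deliberately conservative: a request is only steered by the budget when one exists, so before the first `WatchLoadUpdateReq` arrives, or once the ranks are level, behaviour degrades to plain round robin. A minimal stand-alone sketch of that decision (stubbed worker list, hypothetical names, not the controller itself):

```python
# Hypothetical sketch of the pick-a-worker decision used by
# shortest_queue_scheduler; budget_queue would be refreshed from
# WatchLoadUpdateReq messages in the real controller.
from collections import deque
from itertools import cycle


class WorkerPicker:
    def __init__(self, num_workers: int):
        self.budget_queue = deque()                # dp_ranks known to be under-loaded
        self._round_robin = cycle(range(num_workers))

    def pick(self) -> int:
        if self.budget_queue:                      # budget available: shortest queue wins
            return self.budget_queue.popleft()
        return next(self._round_robin)             # otherwise: round robin fallback


picker = WorkerPicker(num_workers=3)
picker.budget_queue.extend([1, 1, 2])
print([picker.pick() for _ in range(5)])           # [1, 1, 2, 0, 1]
```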
@@ -333,22 +403,7 @@ class DataParallelController:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
             except zmq.ZMQError:
                 break
-            if isinstance(
-                recv_req,
-                (
-                    TokenizedGenerateReqInput,
-                    TokenizedEmbeddingReqInput,
-                ),
-            ):
-                self.dispatching(recv_req)
-            elif isinstance(recv_req, BlockReqInput):
-                for worker in self.workers:
-                    worker.send_pyobj(recv_req)
-            else:
-                # Send other control messages to first worker of tp group
-                for worker in self.workers[:: self.control_message_step]:
-                    worker.send_pyobj(recv_req)
+            self._request_dispatcher(recv_req)

 def run_data_parallel_controller_process(
...
@@ -297,7 +297,7 @@ def run_detokenizer_process(
         else:
             manager.event_loop()
     except Exception:
-        manager.socket_mapping.clear_all_sockets()
+        manager.maybe_clear_socket_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
@@ -1374,3 +1374,21 @@ class BlockReqType(Enum):
 @dataclass
 class BlockReqInput:
     type: BlockReqType
+
+
+@dataclass
+class GetLoadReqInput:
+    pass
+
+
+@dataclass
+class GetLoadReqOutput:
+    dp_rank: int
+    num_reqs: int
+    num_waiting_reqs: int
+    num_tokens: int
+
+
+@dataclass
+class WatchLoadUpdateReq:
+    loads: List[GetLoadReqOutput]
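
> For orientation, these three messages form one round trip: each scheduler answers a `GetLoadReqInput` with a `GetLoadReqOutput`, and the tokenizer side bundles the per-rank replies into a single `WatchLoadUpdateReq` for the data parallel controller. A hedged construction example (field values invented; assumes SGLang with this PR is importable):

```python
from sglang.srt.managers.io_struct import GetLoadReqOutput, WatchLoadUpdateReq

# Invented per-rank replies, as the tokenizer manager would collect them.
loads = [
    GetLoadReqOutput(dp_rank=0, num_reqs=5, num_waiting_reqs=1, num_tokens=2048),
    GetLoadReqOutput(dp_rank=1, num_reqs=2, num_waiting_reqs=0, num_tokens=512),
]

# One aggregated update, sent over ZMQ to the DP controller, which feeds it
# into DPBudget.update_budget.
update = WatchLoadUpdateReq(loads=loads)
print(update.loads[1].dp_rank)  # 1
```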
@@ -354,6 +354,10 @@ class MultiHttpWorkerDetokenizerMixin:
             worker_ids = []
         return worker_ids

+    def maybe_clear_socket_mapping(self):
+        if hasattr(self, "socket_mapping"):
+            self.socket_mapping.clear_all_sockets()
+
     def multi_http_worker_event_loop(self):
         """The event loop that handles requests, for multi-http-worker mode"""
         self.socket_mapping = SocketMapping()
...
@@ -79,6 +79,8 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -577,6 +579,7 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )
@@ -2279,39 +2282,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
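
> As a concrete reading of the accounting above, with invented numbers for the non-hybrid branch: `num_tokens` is the KV cache actually held (capacity minus free minus evictable) plus the prompt tokens still sitting in the waiting/bootstrap/prealloc queues, while `num_reqs` is simply running plus waiting requests, which is what `DPBudget` balances on.

```python
# Invented figures illustrating get_load() for the non-hybrid case.
max_total_num_tokens = 10_000      # KV pool capacity
available = 7_000                  # free KV slots
evictable = 1_000                  # cached prefixes that could be evicted
kv_in_use = max_total_num_tokens - available - evictable     # 2_000

queued_prompt_tokens = 600         # sum of origin_input_ids over the waiting queue
num_tokens = kv_in_use + queued_prompt_tokens                # 2_600

running_reqs, num_waiting_reqs = 4, 3
num_reqs = running_reqs + num_waiting_reqs                    # 7
print(num_tokens, num_reqs)        # 2600 7
```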
@@ -2337,8 +2351,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
...
@@ -279,7 +279,7 @@ class SchedulerMetricsMixin:
             self.server_args.load_balance_method == "minimum_tokens"
             and self.forward_ct % 40 == 0
         ):
-            holding_tokens = self.get_load()
+            holding_tokens = self.get_load().num_tokens

             new_recv_dp_balance_id_list, holding_token_list = (
                 self.gather_dp_balance_info(holding_tokens)
...
 from __future__ import annotations

 import asyncio
+import copy
 import logging
 import os
 import time
@@ -18,6 +19,7 @@ from typing import (
 )

 import fastapi
+import zmq

 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
@@ -28,6 +30,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
@@ -75,14 +79,17 @@ class _Communicator(Generic[T]):
     enable_multi_tokenizer = False

-    def __init__(self, sender, fan_out: int):
+    def __init__(self, sender: zmq.Socket, fan_out: int, mode="queueing"):
         self._sender = sender
         self._fan_out = fan_out
+        self._mode = mode
         self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
         self._ready_queue: Deque[asyncio.Future] = deque()
+        assert mode in ["queueing", "watching"]

-    async def __call__(self, obj):
+    async def queueing_call(self, obj: T):
         ready_event = asyncio.Event()
         if self._result_event is not None or len(self._ready_queue) > 0:
             self._ready_queue.append(ready_event)
@@ -106,6 +113,28 @@ class _Communicator(Generic[T]):

         return result_values

+    async def watching_call(self, obj):
+        if self._result_event is None:
+            assert self._result_values is None
+            self._result_values = []
+            self._result_event = asyncio.Event()
+            if obj:
+                if _Communicator.enable_multi_tokenizer:
+                    obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
+                self._sender.send_pyobj(obj)
+
+        await self._result_event.wait()
+        result_values = copy.deepcopy(self._result_values)
+        self._result_event = self._result_values = None
+        return result_values
+
+    async def __call__(self, obj):
+        if self._mode == "queueing":
+            return await self.queueing_call(obj)
+        else:
+            return await self.watching_call(obj)
+
     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
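
> The two modes differ only in how concurrent callers interact. In queueing mode each caller sends its own request and waits its turn for a full fan-out of replies; in watching mode, a caller that arrives while a fan-out is already pending simply awaits the same event and receives a copy of the same result, so the periodic watcher and any concurrent `get_load()` caller share one probe instead of issuing duplicates. Below is a toy, self-contained version of the watching idea (not the SGLang class; the clear-once bookkeeping is handled a little differently here for brevity):

```python
import asyncio


class WatchingFanIn:
    """Toy 'watching' communicator: concurrent awaiters share one in-flight request."""

    def __init__(self, fan_out, send):
        self._fan_out = fan_out
        self._send = send            # callable that actually issues the request
        self._event = None
        self._values = None

    async def call(self, obj):
        if self._event is None:      # first caller initiates the fan-out
            self._values = []
            self._event = asyncio.Event()
            self._send(obj)
        event, values = self._event, self._values
        await event.wait()           # later callers just piggyback on the same event
        if self._event is event:     # first waiter to resume clears the slot
            self._event = self._values = None
        return list(values)

    def handle_recv(self, value):    # called once per responding rank
        self._values.append(value)
        if len(self._values) == self._fan_out:
            self._event.set()
```

> In the PR, only `get_load_communicator` is constructed with `mode="watching"`; every other communicator keeps the default queueing behaviour.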
@@ -165,6 +194,9 @@ class TokenizerCommunicatorMixin:
         self.update_lora_adapter_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.get_load_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size, mode="watching"
+        )

         self._result_dispatcher += self._get_communicator_dispatcher()
@@ -235,6 +267,10 @@ class TokenizerCommunicatorMixin:
                     LoRAUpdateResult,
                     self.update_lora_adapter_communicator.handle_recv,
                 ),
+                (
+                    GetLoadReqOutput,
+                    self.get_load_communicator.handle_recv,
+                ),
             ]
         )
@@ -528,10 +564,6 @@ class TokenizerCommunicatorMixin:
         )
         return [res.updated for res in responses]

-    async def get_load(self: TokenizerManager) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
+    async def get_load(self: TokenizerManager) -> List[GetLoadReqOutput]:
+        req = GetLoadReqInput()
+        return await self.get_load_communicator(req)
@@ -64,6 +64,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FreezeGCReq,
     GenerateReqInput,
+    GetLoadReqInput,
     HealthCheckOutput,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
@@ -73,6 +74,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
@@ -1240,6 +1242,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1844,6 +1849,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_update_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_update_req)

 class ServerStatus(Enum):
     Up = "Up"
...
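
> The watcher is an ordinary background asyncio task: sleep for `load_watch_interval`, fan one `GetLoadReqInput` out through the watching communicator, and push the aggregated `WatchLoadUpdateReq` back over `send_to_scheduler`, where the DP controller's dispatcher hands it to `handle_load_update_req`. A stripped-down sketch of that poll-and-publish loop with stubbed I/O (not the SGLang classes):

```python
import asyncio


async def watch_loads(poll, publish, interval: float = 0.1):
    """Periodically poll per-rank loads and publish one aggregated update."""
    while True:
        await asyncio.sleep(interval)
        loads = await poll()     # stand-in for the GetLoadReqInput fan-out
        publish(loads)           # stand-in for send_pyobj(WatchLoadUpdateReq(...))


async def main():
    async def fake_poll():
        return [{"dp_rank": 0, "num_reqs": 3}, {"dp_rank": 1, "num_reqs": 1}]

    task = asyncio.create_task(watch_loads(fake_poll, print, interval=0.05))
    await asyncio.sleep(0.2)     # let a few updates go out
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```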
@@ -233,6 +233,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    load_watch_interval: float = 0.1
     # FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
     prefill_round_robin_balance: bool = False
@@ -663,6 +664,7 @@ class ServerArgs:
         if self.dp_size == 1:
             self.enable_dp_attention = False
+            self.enable_dp_lm_head = False

         # Data parallelism attention
         if self.enable_dp_attention:
@@ -1488,6 +1490,12 @@ class ServerArgs:
                 "minimum_tokens",
             ],
         )
+        parser.add_argument(
+            "--load-watch-interval",
+            type=float,
+            default=ServerArgs.load_watch_interval,
+            help="The interval of load watching in seconds.",
+        )
         parser.add_argument(
             "--prefill-round-robin-balance",
             default=ServerArgs.prefill_round_robin_balance,
...
@@ -1160,7 +1160,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
-):
+) -> zmq.Socket:
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
...
@@ -472,6 +472,10 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
 class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
+        self._fallback_fn = None
+
+    def add_fallback_fn(self, fallback_fn: Callable):
+        self._fallback_fn = fallback_fn

     def __iadd__(self, other: "TypeBasedDispatcher"):
         self._mapping.extend(other._mapping)
@@ -481,6 +485,9 @@ class TypeBasedDispatcher:
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
                 return fn(obj)
+        if self._fallback_fn is not None:
+            return self._fallback_fn(obj)
         raise ValueError(f"Invalid object: {obj}")
...
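
> The fallback hook is what lets the DP controller's event loop collapse its isinstance chain into `self._request_dispatcher(recv_req)`: typed requests route to their registered handlers, and anything unrecognised falls through to `send_control_message` instead of raising. A small usage sketch against the `sglang.utils.TypeBasedDispatcher` API shown in this diff (the message classes are made up):

```python
from sglang.utils import TypeBasedDispatcher


# Hypothetical message types standing in for the io_struct requests.
class Generate: ...
class Block: ...
class Flush: ...   # no registered handler, so it hits the fallback


dispatcher = TypeBasedDispatcher(
    [
        (Generate, lambda obj: "dispatched to one worker"),
        (Block, lambda obj: "broadcast to all workers"),
    ]
)
dispatcher.add_fallback_fn(lambda obj: "sent to the first worker of each TP group")

print(dispatcher(Generate()))  # dispatched to one worker
print(dispatcher(Flush()))     # sent to the first worker of each TP group
```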