Unverified Commit a25e8e42 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

Move multi-tokenizer event loop to better place (#9902)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent d4a93841
......@@ -39,7 +39,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
freeze_gc,
get_worker_ids_from_req_rids,
get_zmq_socket,
kill_itself_when_parent_died,
)
......@@ -120,39 +119,6 @@ class DetokenizerManager(MultiTokenizerMixin):
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
def multi_tokenizer_manager_event_loop(self):
"""The event loop that handles requests, for multi tokenizer manager mode only"""
self.create_sockets_mapping()
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
if output is None:
continue
# Extract worker_id from rid
if isinstance(recv_obj.rids, list):
worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
else:
raise RuntimeError(
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
)
# Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id):
logger.info(
f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
)
continue
else:
if worker_id not in self.tokenizer_mapping:
logger.error(
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_output = self._handle_output_by_index(output, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
......
......@@ -37,11 +37,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_worker_ids_from_req_rids,
get_zmq_socket,
kill_process_tree,
)
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -344,6 +340,48 @@ class MultiTokenizerMixin:
new_output = output
return new_output
def get_worker_ids_from_req_rids(self, rids):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def multi_tokenizer_manager_event_loop(self):
"""The event loop that handles requests, for multi tokenizer manager mode only"""
self.create_sockets_mapping()
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
if output is None:
continue
# Extract worker_id from rid
if isinstance(recv_obj.rids, list):
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
else:
raise RuntimeError(
f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
)
# Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id):
logger.info(
f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
)
continue
else:
if worker_id not in self.tokenizer_mapping:
logger.error(
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_output = self._handle_output_by_index(output, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
def clear_tokenizer_mapping(self):
if hasattr(self, "tokenizer_mapping"):
for socket in self.tokenizer_mapping.values():
......@@ -406,7 +444,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
worker_ids = [recv_obj.worker_id]
recv_obj = recv_obj.obj
else:
worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
if len(worker_ids) == 0:
logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
......
......@@ -2787,16 +2787,6 @@ def lru_cache_frozenset(maxsize=128):
return decorator
def get_worker_ids_from_req_rids(rids):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def get_origin_rid(rid):
return rid.split("_", 1)[1] if "_" in rid else rid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment