Unverified commit e719bb0e, authored by Liangsheng Yin and committed by GitHub

[1/2] Refactor multi-tokenizer manager (#10074)

parent 06724683
@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


+def _init_tokenizer_manager(
+    server_args: ServerArgs, port_args: PortArgs
+) -> Tuple[TokenizerManager, TemplateManager]:
+    # Launch tokenizer process
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+
+    return tokenizer_manager, template_manager
+
+
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
 ) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -816,23 +834,15 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()

+    # Init tokenizer manager first, as the bootstrap server is initialized here
     if server_args.tokenizer_worker_num > 1:
         # Launch multi-tokenizer router
         tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
-        # Initialize templates
         template_manager = None
     else:
-        # Launch tokenizer process
-        tokenizer_manager = TokenizerManager(server_args, port_args)
-        # Initialize templates
-        template_manager = TemplateManager()
-        template_manager.initialize_templates(
-            tokenizer_manager=tokenizer_manager,
-            model_path=server_args.model_path,
-            chat_template=server_args.chat_template,
-            completion_template=server_args.completion_template,
+        tokenizer_manager, template_manager = _init_tokenizer_manager(
+            server_args, port_args
         )

     # Wait for the model to finish loading

@@ -856,5 +866,7 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
     return tokenizer_manager, template_manager, scheduler_info
@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
-    deserialize_data,
     get_main_process_id,
     read_from_shared_memory,
     write_data_for_multi_tokenizer,

@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state


-# Function to set up all middlewares for multi-tokenizer compatibility
-def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
-    """Setup all middlewares for both single and multi-process modes"""
-    worker_pid = os.getpid()
-
-    if api_key:
-        add_api_key_middleware(app, api_key)
-        logger.info(f"Worker {worker_pid} added API key middleware")
-
-    if enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
-        logger.info(f"Worker {worker_pid} added prometheus middleware")
-
-
 async def init_multi_tokenizer() -> ServerArgs:
     """Read args information from shm and init tokenizer manager for current process"""
     pid = os.getpid()
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
     logger.info(f"current worker_id: {pid}, main processID: {main_pid}")

     # Read configuration from shared memory
-    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
-    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
-    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
-    port_args, server_args = deserialize_data(port_args_data, server_args_data)
-    scheduler_info = scheduler_info_data
+    port_args, server_args, scheduler_info = read_from_shared_memory(
+        f"multi_tokenizer_args_{main_pid}"
+    )
+    server_args: ServerArgs
+
+    # API key authentication is not supported in multi-tokenizer mode
+    assert (
+        server_args.api_key is None
+    ), "API key is not supported in multi-tokenizer mode"

     port_args.tokenizer_ipc_name = (
         f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args = getattr(fast_api_app, "server_args", None)
-    if server_args is None:
+    if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
         # Initialize multi-tokenizer support for worker processes
-        fast_api_app.server_args = await init_multi_tokenizer()
-        setup_middlewares(
-            fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
-        )
+        fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
+
+        # only metrics middleware is supported in multi-tokenizer mode
+        worker_pid = os.getpid()
+        if fast_api_app.server_args.enable_metrics:
+            add_prometheus_middleware(app)
+            enable_func_timer()
+            logger.info(f"Worker {worker_pid} added prometheus middleware")
+
     fast_api_app.warmup_thread = threading.Thread(
         target=_wait_and_warmup,
         args=(
@@ -1187,13 +1179,11 @@ def launch_server(
     )

     if server_args.tokenizer_worker_num > 1:
-        port_args_shm, server_args_shm, scheduler_info_shm = (
-            write_data_for_multi_tokenizer(
-                port_args,
-                server_args,
-                scheduler_info,
-            )
-        )
+        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
+            port_args,
+            server_args,
+            scheduler_info,
+        )
     else:
         # Add api key authorization
         if server_args.api_key:

@@ -1239,6 +1229,7 @@ def launch_server(
             workers=server_args.tokenizer_worker_num,
         )
     else:
+        app.is_single_tokenizer_mode = True
         uvicorn.run(
             app,
             host=server_args.host,
@@ -1249,10 +1240,8 @@ def launch_server(
         )
     finally:
         if server_args.tokenizer_worker_num > 1:
-            port_args_shm.unlink()
-            server_args_shm.unlink()
-            scheduler_info_shm.unlink()
-            _global_state.tokenizer_manager.clear_tokenizer_mapping()
+            multi_tokenizer_args_shm.unlink()
+            _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
         else:
             warmup_thread.join()
@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
-from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
+from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,

@@ -69,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0


-class DetokenizerManager(MultiTokenizerMixin):
+class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""

     def __init__(

@@ -289,11 +289,11 @@ def run_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
         if server_args.tokenizer_worker_num > 1:
-            manager.multi_tokenizer_manager_event_loop()
+            manager.multi_http_worker_event_loop()
         else:
             manager.event_loop()
     except Exception:
-        manager.clear_tokenizer_mapping()
+        manager.socket_mapping.clear_all_sockets()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
"""Start bootstrap/kv-store-related server"""
import os
from typing import Type
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.server_args import ServerArgs
def start_disagg_service(
server_args: ServerArgs,
):
# Start kv boostrap server on prefill
disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
if disagg_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=server_args.host,
port=server_args.disaggregation_bootstrap_port,
)
is_create_store = (
server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
return bootstrap_server
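Both TokenizerManager and MultiTokenizerRouter now delegate to this helper instead of carrying their own init_disaggregation(). A minimal sketch of the call site, with example arguments (illustrative only, not part of the commit):

    from sglang.srt.managers.disagg_service import start_disagg_service
    from sglang.srt.server_args import ServerArgs

    # Example args; real callers pass the fully populated ServerArgs of the server.
    server_args = ServerArgs(model_path="dummy-model")
    # Returns a BaseKVBootstrapServer on prefill nodes; returns None in other disaggregation modes.
    bootstrap_server = start_disagg_service(server_args)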
@@ -13,21 +13,21 @@
 # ==============================================================================
 """MultiTokenizerMixin is a class that provides necessary methods for MultiTokenizerManager and DetokenizerManager."""
 import asyncio
-import dataclasses
-import json
 import logging
 import multiprocessing as multiprocessing
 import os
+import pickle
 import sys
 import threading
 from multiprocessing import shared_memory
-from typing import Dict
+from typing import Any, Dict

 import setproctitle
 import zmq
 import zmq.asyncio

 from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalOut,
@@ -44,44 +44,42 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)


-class MultiTokenizerMixin:
-    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
-
-    def create_sockets_mapping(self):
-        if not hasattr(self, "tokenizer_mapping"):
-            self.tokenizer_mapping = {}
-        # Create ZMQ context if needed
-        if not hasattr(self, "_zmq_context"):
-            self._zmq_context = zmq.Context()
-
-    def init_tokenizer_mapping(
-        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
-    ):
-        """init tokenizer mapping from register request"""
-        ipc_name = recv_obj.ipc_name
-        worker_id_int = int(worker_id)
-
-        if worker_id_int not in self.tokenizer_mapping:
-            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
-            self.tokenizer_mapping[worker_id_int] = socket
-            self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
-            return True
-        else:
-            return False
-
-    def register_tokenizer_ipc(self, recv_obj, worker_id):
-        if worker_id not in self.tokenizer_mapping:
-            # register the worker if not already done
-            if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                return self.init_tokenizer_mapping(recv_obj, worker_id)
-            else:
-                logger.error(
-                    f"Worker {worker_id} not registered and not found in tokenizer mapping. "
-                    "Please ensure the worker is registered correctly."
-                )
-                return False
-
-    def _handle_output_by_index(self, output, i):
-        """NOTE: A maintainable method is better here."""
-        if isinstance(output, BatchTokenIDOut):
-            new_output = BatchTokenIDOut(
+class SocketMapping:
+    def __init__(self):
+        self._zmq_context = zmq.Context()
+        self._mapping: Dict[str, zmq.Socket] = {}
+
+    def clear_all_sockets(self):
+        for socket in self._mapping.values():
+            socket.close()
+        self._mapping.clear()
+
+    def register_ipc_mapping(
+        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
+    ):
+        type_str = "tokenizer" if is_tokenizer else "detokenizer"
+        if worker_id in self._mapping:
+            logger.warning(
+                f"{type_str} already registered with worker {worker_id}, skipping..."
+            )
+            return
+        logger.info(
+            f"{type_str} not registered with worker {worker_id}, registering..."
+        )
+        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
+        self._mapping[worker_id] = socket
+        self._mapping[worker_id].send_pyobj(recv_obj)
+
+    def send_output(self, worker_id: str, output: Any):
+        if worker_id not in self._mapping:
+            logger.error(
+                f"worker ID {worker_id} not registered. Check if the server Process is alive"
+            )
+            return
+        self._mapping[worker_id].send_pyobj(output)
+
+
+def _handle_output_by_index(output, i):
+    """NOTE: A maintainable method is better here."""
+    if isinstance(output, BatchTokenIDOut):
+        new_output = BatchTokenIDOut(
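The event loops further down drive SocketMapping in two steps: a MultiTokenizerRegisterReq opens a PUSH socket to the sender's IPC endpoint, and every later output is pushed to the already-registered worker. A condensed sketch of that pattern, distilled from the loops below (illustrative, not code from the commit):

    # Assumes SocketMapping, MultiTokenizerRegisterReq, and _handle_output_by_index
    # from this module are in scope.
    socket_mapping = SocketMapping()

    def route(recv_obj, output, worker_ids, is_tokenizer):
        for i, worker_id in enumerate(worker_ids):
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                # First message from a worker: open a PUSH socket to recv_obj.ipc_name.
                socket_mapping.register_ipc_mapping(recv_obj, worker_id, is_tokenizer=is_tokenizer)
            else:
                # Slice the batched output down to the i-th request and push it to its worker.
                socket_mapping.send_output(worker_id, _handle_output_by_index(output, i))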
@@ -94,9 +92,7 @@ class MultiTokenizerMixin:
             decoded_texts=(
                 [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
             ),
-            decode_ids=(
-                [output.decode_ids[i]] if len(output.decode_ids) > i else None
-            ),
+            decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
             read_offsets=(
                 [output.read_offsets[i]] if len(output.read_offsets) > i else None
             ),
@@ -130,9 +126,7 @@ class MultiTokenizerMixin:
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
             spec_verify_ct=(
-                [output.spec_verify_ct[i]]
-                if len(output.spec_verify_ct) > i
-                else None
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
@@ -208,9 +202,7 @@ class MultiTokenizerMixin:
                 if len(output.finished_reasons) > i
                 else None
             ),
-            embeddings=(
-                [output.embeddings[i]] if len(output.embeddings) > i else None
-            ),
+            embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
             prompt_tokens=(
                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
             ),
@@ -246,9 +238,7 @@ class MultiTokenizerMixin:
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
             spec_verify_ct=(
-                [output.spec_verify_ct[i]]
-                if len(output.spec_verify_ct) > i
-                else None
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
             ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
@@ -341,6 +331,10 @@ class MultiTokenizerMixin:
             new_output = output
         return new_output


+class MultiHttpWorkerDetokenizerMixin:
+    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
     def get_worker_ids_from_req_rids(self, rids):
         if isinstance(rids, list):
             worker_ids = [int(rid.split("_")[0]) for rid in rids]
@@ -350,9 +344,9 @@ class MultiTokenizerMixin:
             worker_ids = []
         return worker_ids

-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
+    def multi_http_worker_event_loop(self):
+        """The event loop that handles requests, for multi-http-worker mode only"""
+        self.socket_mapping = SocketMapping()
         while True:
             recv_obj = self.recv_from_scheduler.recv_pyobj()
             output = self._request_dispatcher(recv_obj)
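Request ids carry the HTTP worker's process id as a prefix, which is what get_worker_ids_from_req_rids relies on to find the destination worker. A small illustration of that convention (hypothetical rid, not from the commit):

    import os

    worker_id = os.getpid()          # each HTTP worker uses its own pid as worker_id
    rid = f"{worker_id}_abc123"      # rids look like "<worker_pid>_<original_request_id>"
    assert int(rid.split("_")[0]) == worker_id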
@@ -369,31 +363,15 @@ class MultiTokenizerMixin:
             # Send data using the corresponding socket
             for i, worker_id in enumerate(worker_ids):
                 if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    if self.register_tokenizer_ipc(recv_obj, worker_id):
-                        logger.info(
-                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                        )
-                    continue
+                    self.socket_mapping.register_ipc_mapping(
+                        recv_obj, worker_id, is_tokenizer=False
+                    )
                 else:
-                    if worker_id not in self.tokenizer_mapping:
-                        logger.error(
-                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                        )
-                        continue
-                    new_output = self._handle_output_by_index(output, i)
-                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
-    def clear_tokenizer_mapping(self):
-        if hasattr(self, "tokenizer_mapping"):
-            for socket in self.tokenizer_mapping.values():
-                try:
-                    socket.close()
-                except Exception as e:
-                    logger.warning(f"Failed to close socket: {e}")
-            self.tokenizer_mapping.clear()
+                    new_output = _handle_output_by_index(output, i)
+                    self.socket_mapping.send_output(worker_id, new_output)


-class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
+class MultiTokenizerRouter:
     """A router to receive requests from MultiTokenizerManager"""

     def __init__(
@@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
         self._handle_task = asyncio.run_coroutine_threadsafe(
             print_exception_wrapper(self.handle_loop), self._loop
         )
-        self.init_disaggregation()
+        self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)

     def _run_loop(self):
         self._loop.run_forever()

@@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
     async def handle_loop(self):
         # special reqs will recv from scheduler, need to route to right worker
-        self.create_sockets_mapping()
+        self.socket_mapping = SocketMapping()
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             await self._distribute_result_to_workers(recv_obj)
@@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
         # Distribute result to each worker
         for i, worker_id in enumerate(worker_ids):
             if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                if self.register_tokenizer_ipc(recv_obj, worker_id):
-                    logger.info(
-                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
-                    )
-                continue
+                self.socket_mapping.register_ipc_mapping(
+                    recv_obj, worker_id, is_tokenizer=True
+                )
             else:
-                if worker_id not in self.tokenizer_mapping:
-                    logger.error(
-                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                    )
-                    continue
-                new_recv_obj = self._handle_output_by_index(recv_obj, i)
-                self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
+                new_recv_obj = _handle_output_by_index(recv_obj, i)
+                self.socket_mapping.send_output(worker_id, new_recv_obj)


-class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
+class MultiTokenizerManager(TokenizerManager):
     """Multi Process Tokenizer Manager that tokenizes the text."""

     def __init__(
@@ -535,42 +506,14 @@ async def print_exception_wrapper(func):
         sys.exit(1)


-def serialize_port_args(port_args: PortArgs) -> dict:
-    """Serialize PortArgs into a shareable dictionary"""
-    return {
-        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
-        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
-        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
-        "nccl_port": port_args.nccl_port,
-        "rpc_ipc_name": port_args.rpc_ipc_name,
-        "metrics_ipc_name": port_args.metrics_ipc_name,
-        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
-    }
-
-
-def deserialize_data(port_args: dict, server_args: dict):
-    """Deserialize data from shared dictionaries"""
-    return PortArgs(**port_args), ServerArgs(**server_args)
-
-
-def serialize_server_args(server_args: ServerArgs) -> dict:
-    """Serialize ServerArgs into a shareable dictionary"""
-    return dataclasses.asdict(server_args)
-
-
-def serialize_scheduler_info(scheduler_info: Dict) -> dict:
-    """Serialize scheduler_info into a shareable dictionary"""
-    return scheduler_info
-
-
-def deserialize_scheduler_info(data: dict) -> Dict:
-    """Deserialize scheduler_info from a shared dictionary"""
-    return data
+def get_main_process_id() -> int:
+    """Get the main process ID"""
+    return multiprocessing.current_process()._parent_pid


-def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
+def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
     """Write data to shared memory"""
-    serialized = json.dumps(data).encode("utf-8")
+    serialized = pickle.dumps(obj)
     size = len(serialized)
     try:
         # Try to open existing shared memory
@@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
     return shm


-def read_from_shared_memory(name: str) -> dict:
+def read_from_shared_memory(name: str) -> Any:
     """Read data from shared memory"""
     try:
         shm = shared_memory.SharedMemory(name=name)
-        data = json.loads(bytes(shm.buf).decode("utf-8"))
+        data = pickle.loads(bytes(shm.buf))
         shm.close()
         return data
     except FileNotFoundError:
         raise FileNotFoundError(f"Shared memory {name} not found")


-def get_main_process_id() -> int:
-    """Get the main process ID"""
-    return multiprocessing.current_process()._parent_pid
-
-
 def write_data_for_multi_tokenizer(
     port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
 ):
@@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer(
     main_pid = get_main_process_id()
     current_pid = os.getpid()
     logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

-    # Write port_args to shared memory
-    port_args_shm = write_to_shared_memory(
-        serialize_port_args(port_args), f"port_args_{current_pid}"
-    )
-    # Write server_args to shared memory
-    server_args_shm = write_to_shared_memory(
-        serialize_server_args(server_args), f"server_args_{current_pid}"
-    )
-    # Write scheduler_info to shared memory
-    scheduler_info_shm = write_to_shared_memory(
-        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
-    )
-
-    port_args_shm.close()
-    server_args_shm.close()
-    scheduler_info_shm.close()
-
-    return port_args_shm, server_args_shm, scheduler_info_shm
+    args = (port_args, server_args, scheduler_info)
+    args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
+    args_shm.close()
+
+    return args_shm
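The three JSON-serialized shared-memory segments are replaced by a single pickled tuple. A self-contained sketch of the same handoff using only the standard library (demo names and values, not the commit's code):

    import pickle
    from multiprocessing import shared_memory

    # Main process: pickle one tuple and publish it under a well-known name.
    payload = ({"tokenizer_ipc_name": "ipc:///tmp/x"}, {"tokenizer_worker_num": 4}, {"max_req_input_len": 8192})
    data = pickle.dumps(payload)
    shm = shared_memory.SharedMemory(create=True, size=len(data), name="multi_tokenizer_args_demo")
    shm.buf[: len(data)] = data

    # Worker process: attach by name; pickle ignores any trailing padding bytes in the buffer.
    reader = shared_memory.SharedMemory(name="multi_tokenizer_args_demo")
    port_args, server_args, scheduler_info = pickle.loads(bytes(reader.buf))
    reader.close()

    # Main process again: unlink once every worker has read the args
    # (mirrors the finally block in launch_server).
    shm.close()
    shm.unlink()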
@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.base import BaseKVBootstrapServer
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -321,8 +316,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0
@@ -471,38 +468,6 @@ class TokenizerManager:
             ]
         )

-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
-                host=self.server_args.host,
-                port=self.server_args.disaggregation_bootstrap_port,
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],