Unverified Commit e719bb0e authored by Liangsheng Yin, committed by GitHub

[1/2] Refactor multi-tokenizer manager (#10074)

parent 06724683
......@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
def _init_tokenizer_manager(
server_args: ServerArgs, port_args: PortArgs
) -> Tuple[TokenizerManager, TemplateManager]:
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
)
return tokenizer_manager, template_manager
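# Sketch of the post-refactor branch in _launch_subprocesses (assumed from the
# hunks below; the exact control flow in the final file may differ slightly).
if server_args.tokenizer_worker_num > 1:
    # Multi-HTTP-worker mode: a router fans results out to per-worker tokenizers,
    # so no TemplateManager is built here.
    tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
    template_manager = None
else:
    # Single-tokenizer mode: build the manager and its chat/completion templates.
    tokenizer_manager, template_manager = _init_tokenizer_manager(server_args, port_args)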
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
......@@ -816,23 +834,15 @@ def _launch_subprocesses(
),
)
detoken_proc.start()
# Init tokenizer manager first, as the bootstrap server is initialized here
if server_args.tokenizer_worker_num > 1:
# Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
# Initialize templates
template_manager = None
else:
# Launch tokenizer process
tokenizer_manager = TokenizerManager(server_args, port_args)
# Initialize templates
template_manager = TemplateManager()
template_manager.initialize_templates(
tokenizer_manager=tokenizer_manager,
model_path=server_args.model_path,
chat_template=server_args.chat_template,
completion_template=server_args.completion_template,
tokenizer_manager, template_manager = _init_tokenizer_manager(
server_args, port_args
)
# Wait for the model to finish loading
......@@ -856,5 +866,7 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info
......@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
deserialize_data,
get_main_process_id,
read_from_shared_memory,
write_data_for_multi_tokenizer,
......@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
# Function to set up all middlewares for multi-tokenizer compatibility
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
"""Setup all middlewares for both single and multi-process modes"""
worker_pid = os.getpid()
if api_key:
add_api_key_middleware(app, api_key)
logger.info(f"Worker {worker_pid} added API key middleware")
if enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
async def init_multi_tokenizer() -> ServerArgs:
"""Read args information from shm and init tokenizer manager for current process"""
pid = os.getpid()
......@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
# Read configuration from shared memory
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
port_args, server_args = deserialize_data(port_args_data, server_args_data)
scheduler_info = scheduler_info_data
port_args, server_args, scheduler_info = read_from_shared_memory(
f"multi_tokenizer_args_{main_pid}"
)
server_args: ServerArgs
# API key authentication is not supported in multi-tokenizer mode
assert (
server_args.api_key is None
), "API key is not supported in multi-tokenizer mode"
port_args.tokenizer_ipc_name = (
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
......@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
# Initialize multi-tokenizer support for worker processes
fast_api_app.server_args = await init_multi_tokenizer()
setup_middlewares(
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
)
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
# only metrics middleware is supported in multi-tokenizer mode
worker_pid = os.getpid()
if fast_api_app.server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
logger.info(f"Worker {worker_pid} added prometheus middleware")
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
......@@ -1187,12 +1179,10 @@ def launch_server(
)
if server_args.tokenizer_worker_num > 1:
port_args_shm, server_args_shm, scheduler_info_shm = (
write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
port_args,
server_args,
scheduler_info,
)
else:
# Add api key authorization
......@@ -1239,6 +1229,7 @@ def launch_server(
workers=server_args.tokenizer_worker_num,
)
else:
app.is_single_tokenizer_mode = True
uvicorn.run(
app,
host=server_args.host,
......@@ -1249,10 +1240,8 @@ def launch_server(
)
finally:
if server_args.tokenizer_worker_num > 1:
port_args_shm.unlink()
server_args_shm.unlink()
scheduler_info_shm.unlink()
_global_state.tokenizer_manager.clear_tokenizer_mapping()
multi_tokenizer_args_shm.unlink()
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
else:
warmup_thread.join()
......
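# Lifecycle sketch for multi-tokenizer mode (tokenizer_worker_num > 1), pieced
# together from the hunks above; the uvicorn import string is an assumption.
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
    port_args, server_args, scheduler_info
)  # main process pickles (port_args, server_args, scheduler_info) into shared memory
try:
    uvicorn.run(
        "sglang.srt.entrypoints.http_server:app",  # each worker re-imports the app
        host=server_args.host,
        port=server_args.port,
        workers=server_args.tokenizer_worker_num,
    )
    # Every HTTP worker reads f"multi_tokenizer_args_{main_pid}" inside
    # lifespan() -> init_multi_tokenizer() and builds its own MultiTokenizerManager.
finally:
    multi_tokenizer_args_shm.unlink()  # main process removes the shared segment
    _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()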
......@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
FreezeGCReq,
MultiTokenizerRegisterReq,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
......@@ -69,7 +69,7 @@ class DecodeStatus:
sent_offset: int = 0
class DetokenizerManager(MultiTokenizerMixin):
class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
"""DetokenizerManager is a process that detokenizes the token ids."""
def __init__(
......@@ -289,11 +289,11 @@ def run_detokenizer_process(
try:
manager = DetokenizerManager(server_args, port_args)
if server_args.tokenizer_worker_num > 1:
manager.multi_tokenizer_manager_event_loop()
manager.multi_http_worker_event_loop()
else:
manager.event_loop()
except Exception:
manager.clear_tokenizer_mapping()
manager.socket_mapping.clear_all_sockets()
traceback = get_exception_traceback()
logger.error(f"DetokenizerManager hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
"""Start bootstrap/kv-store-related server"""
import os
from typing import Type
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.server_args import ServerArgs
def start_disagg_service(
server_args: ServerArgs,
):
# Start kv bootstrap server on prefill
disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
if disagg_mode == DisaggregationMode.PREFILL:
# Only start the bootstrap server on the prefill tokenizer manager
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=server_args.host,
port=server_args.disaggregation_bootstrap_port,
)
is_create_store = (
server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
return bootstrap_server
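# Usage sketch (illustration only, not part of this commit): how a
# tokenizer-manager-like object is expected to consume the service above;
# init_manager_disaggregation is a hypothetical helper name.
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.server_args import ServerArgs


def init_manager_disaggregation(server_args: ServerArgs):
    # The bootstrap server is only created in PREFILL mode, so the result may be
    # None and callers should keep it optional (mirrors the TokenizerManager and
    # MultiTokenizerRouter call sites later in this diff).
    bootstrap_server = start_disagg_service(server_args)
    return bootstrap_server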
......@@ -13,21 +13,21 @@
# ==============================================================================
"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
import asyncio
import dataclasses
import json
import logging
import multiprocessing as multiprocessing
import os
import pickle
import sys
import threading
from multiprocessing import shared_memory
from typing import Dict
from typing import Any, Dict
import setproctitle
import zmq
import zmq.asyncio
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalOut,
......@@ -44,302 +44,296 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class MultiTokenizerMixin:
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
class SocketMapping:
def __init__(self):
self._zmq_context = zmq.Context()
self._mapping: Dict[str, zmq.Socket] = {}
def create_sockets_mapping(self):
if not hasattr(self, "tokenizer_mapping"):
self.tokenizer_mapping = {}
# Create ZMQ context if needed
if not hasattr(self, "_zmq_context"):
self._zmq_context = zmq.Context()
def clear_all_sockets(self):
for socket in self._mapping.values():
socket.close()
self._mapping.clear()
def init_tokenizer_mapping(
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
def register_ipc_mapping(
self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
):
"""init tokenizer mapping from register request"""
ipc_name = recv_obj.ipc_name
worker_id_int = int(worker_id)
if worker_id_int not in self.tokenizer_mapping:
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
self.tokenizer_mapping[worker_id_int] = socket
self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
return True
else:
return False
def register_tokenizer_ipc(self, recv_obj, worker_id):
if worker_id not in self.tokenizer_mapping:
# register the worker if not already done
if isinstance(recv_obj, MultiTokenizerRegisterReq):
return self.init_tokenizer_mapping(recv_obj, worker_id)
else:
logger.error(
f"Worker {worker_id} not registered and not found in tokenizer mapping . "
"Please ensure the worker is registered correctly."
)
return False
def _handle_output_by_index(self, output, i):
"""NOTE: A maintainable method is better here."""
if isinstance(output, BatchTokenIDOut):
new_output = BatchTokenIDOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
decoded_texts=(
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None
),
decode_ids=(
[output.decode_ids[i]] if len(output.decode_ids) > i else None
),
read_offsets=(
[output.read_offsets[i]] if len(output.read_offsets) > i else None
),
output_ids=(
[output.output_ids[i]]
if output.output_ids and len(output.output_ids) > i
else None
),
skip_special_tokens=(
[output.skip_special_tokens[i]]
if len(output.skip_special_tokens) > i
else None
),
spaces_between_special_tokens=(
[output.spaces_between_special_tokens[i]]
if len(output.spaces_between_special_tokens) > i
else None
),
no_stop_trim=(
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]]
if len(output.spec_verify_ct) > i
else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
else None
),
input_token_logprobs_idx=(
[output.input_token_logprobs_idx[i]]
if output.input_token_logprobs_idx
else None
),
output_token_logprobs_val=(
[output.output_token_logprobs_val[i]]
if output.output_token_logprobs_val
else None
),
output_token_logprobs_idx=(
[output.output_token_logprobs_idx[i]]
if output.output_token_logprobs_idx
else None
),
input_top_logprobs_val=(
[output.input_top_logprobs_val[i]]
if output.input_top_logprobs_val
else None
),
input_top_logprobs_idx=(
[output.input_top_logprobs_idx[i]]
if output.input_top_logprobs_idx
else None
),
output_top_logprobs_val=(
[output.output_top_logprobs_val[i]]
if output.output_top_logprobs_val
else None
),
output_top_logprobs_idx=(
[output.output_top_logprobs_idx[i]]
if output.output_top_logprobs_idx
else None
),
input_token_ids_logprobs_val=(
[output.input_token_ids_logprobs_val[i]]
if output.input_token_ids_logprobs_val
else None
),
input_token_ids_logprobs_idx=(
[output.input_token_ids_logprobs_idx[i]]
if output.input_token_ids_logprobs_idx
else None
),
output_token_ids_logprobs_val=(
[output.output_token_ids_logprobs_val[i]]
if output.output_token_ids_logprobs_val
else None
),
output_token_ids_logprobs_idx=(
[output.output_token_ids_logprobs_idx[i]]
if output.output_token_ids_logprobs_idx
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
else None
),
type_str = "tokenizer" if is_tokenizer else "detokenizer"
if worker_id in self._mapping:
logger.warning(
f"{type_str} already registered with worker {worker_id}, skipping..."
)
elif isinstance(output, BatchEmbeddingOut):
new_output = BatchEmbeddingOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
embeddings=(
[output.embeddings[i]] if len(output.embeddings) > i else None
),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
)
elif isinstance(output, BatchStrOut):
new_output = BatchStrOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
output_strs=(
[output.output_strs[i]] if len(output.output_strs) > i else None
),
output_ids=(
[output.output_ids[i]]
if output.output_ids and len(output.output_ids) > i
else None
),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]]
if len(output.spec_verify_ct) > i
else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
else None
),
input_token_logprobs_idx=(
[output.input_token_logprobs_idx[i]]
if output.input_token_logprobs_idx
else None
),
output_token_logprobs_val=(
[output.output_token_logprobs_val[i]]
if output.output_token_logprobs_val
else None
),
output_token_logprobs_idx=(
[output.output_token_logprobs_idx[i]]
if output.output_token_logprobs_idx
else None
),
input_top_logprobs_val=(
[output.input_top_logprobs_val[i]]
if output.input_top_logprobs_val
else None
),
input_top_logprobs_idx=(
[output.input_top_logprobs_idx[i]]
if output.input_top_logprobs_idx
else None
),
output_top_logprobs_val=(
[output.output_top_logprobs_val[i]]
if output.output_top_logprobs_val
else None
),
output_top_logprobs_idx=(
[output.output_top_logprobs_idx[i]]
if output.output_top_logprobs_idx
else None
),
input_token_ids_logprobs_val=(
[output.input_token_ids_logprobs_val[i]]
if output.input_token_ids_logprobs_val
else None
),
input_token_ids_logprobs_idx=(
[output.input_token_ids_logprobs_idx[i]]
if output.input_token_ids_logprobs_idx
else None
),
output_token_ids_logprobs_val=(
[output.output_token_ids_logprobs_val[i]]
if output.output_token_ids_logprobs_val
else None
),
output_token_ids_logprobs_idx=(
[output.output_token_ids_logprobs_idx[i]]
if output.output_token_ids_logprobs_idx
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
else None
),
)
elif isinstance(output, BatchMultimodalOut):
new_output = BatchMultimodalOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
outputs=([output.outputs[i]] if len(output.outputs) > i else None),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
return
logger.info(
f"{type_str} not registered with worker {worker_id}, registering..."
)
socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
self._mapping[worker_id] = socket
self._mapping[worker_id].send_pyobj(recv_obj)
def send_output(self, worker_id: str, output: Any):
if worker_id not in self._mapping:
logger.error(
f"worker ID {worker_id} not registered. Check if the server Process is alive"
)
else:
new_output = output
return new_output
return
self._mapping[worker_id].send_pyobj(output)
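# Usage sketch (illustration only, not part of this commit): how the router /
# detokenizer loops in this diff drive SocketMapping. _FakeRegisterReq is a
# hypothetical stand-in for MultiTokenizerRegisterReq; only ipc_name is read here.
class _FakeRegisterReq:
    ipc_name = "ipc:///tmp/example_tokenizer_worker"


socket_mapping = SocketMapping()
# First message from a worker: open a PUSH socket to its IPC name and echo the
# registration back (is_tokenizer only changes the log label).
socket_mapping.register_ipc_mapping(_FakeRegisterReq(), "12345", is_tokenizer=False)
# Later per-request results are routed through the cached socket; an unknown
# worker id is logged and dropped rather than raising.
socket_mapping.send_output("12345", {"demo": "payload"})
# On shutdown, close every cached socket.
socket_mapping.clear_all_sockets()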
def _handle_output_by_index(output, i):
"""NOTE: A maintainable method is better here."""
if isinstance(output, BatchTokenIDOut):
new_output = BatchTokenIDOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
decoded_texts=(
[output.decoded_texts[i]] if len(output.decoded_texts) > i else None
),
decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
read_offsets=(
[output.read_offsets[i]] if len(output.read_offsets) > i else None
),
output_ids=(
[output.output_ids[i]]
if output.output_ids and len(output.output_ids) > i
else None
),
skip_special_tokens=(
[output.skip_special_tokens[i]]
if len(output.skip_special_tokens) > i
else None
),
spaces_between_special_tokens=(
[output.spaces_between_special_tokens[i]]
if len(output.spaces_between_special_tokens) > i
else None
),
no_stop_trim=(
[output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
else None
),
input_token_logprobs_idx=(
[output.input_token_logprobs_idx[i]]
if output.input_token_logprobs_idx
else None
),
output_token_logprobs_val=(
[output.output_token_logprobs_val[i]]
if output.output_token_logprobs_val
else None
),
output_token_logprobs_idx=(
[output.output_token_logprobs_idx[i]]
if output.output_token_logprobs_idx
else None
),
input_top_logprobs_val=(
[output.input_top_logprobs_val[i]]
if output.input_top_logprobs_val
else None
),
input_top_logprobs_idx=(
[output.input_top_logprobs_idx[i]]
if output.input_top_logprobs_idx
else None
),
output_top_logprobs_val=(
[output.output_top_logprobs_val[i]]
if output.output_top_logprobs_val
else None
),
output_top_logprobs_idx=(
[output.output_top_logprobs_idx[i]]
if output.output_top_logprobs_idx
else None
),
input_token_ids_logprobs_val=(
[output.input_token_ids_logprobs_val[i]]
if output.input_token_ids_logprobs_val
else None
),
input_token_ids_logprobs_idx=(
[output.input_token_ids_logprobs_idx[i]]
if output.input_token_ids_logprobs_idx
else None
),
output_token_ids_logprobs_val=(
[output.output_token_ids_logprobs_val[i]]
if output.output_token_ids_logprobs_val
else None
),
output_token_ids_logprobs_idx=(
[output.output_token_ids_logprobs_idx[i]]
if output.output_token_ids_logprobs_idx
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
else None
),
)
elif isinstance(output, BatchEmbeddingOut):
new_output = BatchEmbeddingOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
)
elif isinstance(output, BatchStrOut):
new_output = BatchStrOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
output_strs=(
[output.output_strs[i]] if len(output.output_strs) > i else None
),
output_ids=(
[output.output_ids[i]]
if output.output_ids and len(output.output_ids) > i
else None
),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
spec_verify_ct=(
[output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
),
input_token_logprobs_val=(
[output.input_token_logprobs_val[i]]
if output.input_token_logprobs_val
else None
),
input_token_logprobs_idx=(
[output.input_token_logprobs_idx[i]]
if output.input_token_logprobs_idx
else None
),
output_token_logprobs_val=(
[output.output_token_logprobs_val[i]]
if output.output_token_logprobs_val
else None
),
output_token_logprobs_idx=(
[output.output_token_logprobs_idx[i]]
if output.output_token_logprobs_idx
else None
),
input_top_logprobs_val=(
[output.input_top_logprobs_val[i]]
if output.input_top_logprobs_val
else None
),
input_top_logprobs_idx=(
[output.input_top_logprobs_idx[i]]
if output.input_top_logprobs_idx
else None
),
output_top_logprobs_val=(
[output.output_top_logprobs_val[i]]
if output.output_top_logprobs_val
else None
),
output_top_logprobs_idx=(
[output.output_top_logprobs_idx[i]]
if output.output_top_logprobs_idx
else None
),
input_token_ids_logprobs_val=(
[output.input_token_ids_logprobs_val[i]]
if output.input_token_ids_logprobs_val
else None
),
input_token_ids_logprobs_idx=(
[output.input_token_ids_logprobs_idx[i]]
if output.input_token_ids_logprobs_idx
else None
),
output_token_ids_logprobs_val=(
[output.output_token_ids_logprobs_val[i]]
if output.output_token_ids_logprobs_val
else None
),
output_token_ids_logprobs_idx=(
[output.output_token_ids_logprobs_idx[i]]
if output.output_token_ids_logprobs_idx
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
else None
),
)
elif isinstance(output, BatchMultimodalOut):
new_output = BatchMultimodalOut(
rids=[output.rids[i]],
finished_reasons=(
[output.finished_reasons[i]]
if len(output.finished_reasons) > i
else None
),
outputs=([output.outputs[i]] if len(output.outputs) > i else None),
prompt_tokens=(
[output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
),
completion_tokens=(
[output.completion_tokens[i]]
if len(output.completion_tokens) > i
else None
),
cached_tokens=(
[output.cached_tokens[i]] if len(output.cached_tokens) > i else None
),
)
else:
new_output = output
return new_output
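# Illustration only: _handle_output_by_index narrows a batched output down to the
# single request at position i so it can be routed to the worker owning that rid.
# TinyBatchOut is a hypothetical, simplified stand-in for the Batch*Out classes.
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class TinyBatchOut:
    """Hypothetical stand-in for a BatchStrOut-style batch (illustration only)."""
    rids: List[str]
    output_strs: Optional[List[str]] = None


def handle_tiny_output_by_index(output: TinyBatchOut, i: int) -> TinyBatchOut:
    # Same guard pattern as above: keep a one-element list when the field is
    # populated for index i, otherwise fall back to None.
    return TinyBatchOut(
        rids=[output.rids[i]],
        output_strs=(
            [output.output_strs[i]]
            if output.output_strs and len(output.output_strs) > i
            else None
        ),
    )


batch = TinyBatchOut(rids=["7_a", "9_b"], output_strs=["hello", "world"])
print(handle_tiny_output_by_index(batch, 1))  # -> rids=['9_b'], output_strs=['world']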
class MultiHttpWorkerDetokenizerMixin:
"""Mixin class for MultiTokenizerManager and DetokenizerManager"""
def get_worker_ids_from_req_rids(self, rids):
if isinstance(rids, list):
......@@ -350,9 +344,9 @@ class MultiTokenizerMixin:
worker_ids = []
return worker_ids
def multi_tokenizer_manager_event_loop(self):
"""The event loop that handles requests, for multi tokenizer manager mode only"""
self.create_sockets_mapping()
def multi_http_worker_event_loop(self):
"""The event loop that handles requests, for multi multi-http-worker mode"""
self.socket_mapping = SocketMapping()
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
......@@ -369,31 +363,15 @@ class MultiTokenizerMixin:
# Send data using the corresponding socket
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id):
logger.info(
f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
)
continue
self.socket_mapping.register_ipc_mapping(
recv_obj, worker_id, is_tokenizer=False
)
else:
if worker_id not in self.tokenizer_mapping:
logger.error(
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_output = self._handle_output_by_index(output, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_output)
def clear_tokenizer_mapping(self):
if hasattr(self, "tokenizer_mapping"):
for socket in self.tokenizer_mapping.values():
try:
socket.close()
except Exception as e:
logger.warning(f"Failed to close socket: {e}")
self.tokenizer_mapping.clear()
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
new_output = _handle_output_by_index(output, i)
self.socket_mapping.send_output(worker_id, new_output)
class MultiTokenizerRouter:
"""A router to receive requests from MultiTokenizerManager"""
def __init__(
......@@ -422,7 +400,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
self._handle_task = asyncio.run_coroutine_threadsafe(
print_exception_wrapper(self.handle_loop), self._loop
)
self.init_disaggregation()
self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
def _run_loop(self):
self._loop.run_forever()
......@@ -434,7 +412,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
async def handle_loop(self):
# special reqs will recv from scheduler, need to route to right worker
self.create_sockets_mapping()
self.socket_mapping = SocketMapping()
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
await self._distribute_result_to_workers(recv_obj)
......@@ -454,22 +432,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
# Distribute result to each worker
for i, worker_id in enumerate(worker_ids):
if isinstance(recv_obj, MultiTokenizerRegisterReq):
if self.register_tokenizer_ipc(recv_obj, worker_id):
logger.info(
f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
)
continue
self.socket_mapping.register_ipc_mapping(
recv_obj, worker_id, is_tokenizer=True
)
else:
if worker_id not in self.tokenizer_mapping:
logger.error(
f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
)
continue
new_recv_obj = self._handle_output_by_index(recv_obj, i)
self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
new_recv_obj = _handle_output_by_index(recv_obj, i)
self.socket_mapping.send_output(worker_id, new_recv_obj)
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
class MultiTokenizerManager(TokenizerManager):
"""Multi Process Tokenizer Manager that tokenizes the text."""
def __init__(
......@@ -535,42 +506,14 @@ async def print_exception_wrapper(func):
sys.exit(1)
def serialize_port_args(port_args: PortArgs) -> dict:
"""Serialize PortArgs into a shareable dictionary"""
return {
"tokenizer_ipc_name": port_args.tokenizer_ipc_name,
"scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
"detokenizer_ipc_name": port_args.detokenizer_ipc_name,
"nccl_port": port_args.nccl_port,
"rpc_ipc_name": port_args.rpc_ipc_name,
"metrics_ipc_name": port_args.metrics_ipc_name,
"tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
}
def deserialize_data(port_args: dict, server_args: dict):
"""Deserialize data from shared dictionaries"""
return PortArgs(**port_args), ServerArgs(**server_args)
def serialize_server_args(server_args: ServerArgs) -> dict:
"""Serialize ServerArgs into a shareable dictionary"""
return dataclasses.asdict(server_args)
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
"""Serialize scheduler_info into a shareable dictionary"""
return scheduler_info
def deserialize_scheduler_info(data: dict) -> Dict:
"""Deserialize scheduler_info from a shared dictionary"""
return data
def get_main_process_id() -> int:
"""Get the main process ID"""
return multiprocessing.current_process()._parent_pid
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
"""Write data to shared memory"""
serialized = json.dumps(data).encode("utf-8")
serialized = pickle.dumps(obj)
size = len(serialized)
try:
# Try to open existing shared memory
......@@ -588,22 +531,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
return shm
def read_from_shared_memory(name: str) -> dict:
def read_from_shared_memory(name: str) -> Any:
"""Read data from shared memory"""
try:
shm = shared_memory.SharedMemory(name=name)
data = json.loads(bytes(shm.buf).decode("utf-8"))
data = pickle.loads(bytes(shm.buf))
shm.close()
return data
except FileNotFoundError:
raise FileNotFoundError(f"Shared memory {name} not found")
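# Round-trip sketch (illustration only): the helpers above now pickle arbitrary
# objects instead of JSON dicts; "demo_args" is a hypothetical segment name.
from multiprocessing import shared_memory

payload = {"max_req_input_len": 4096, "note": "any picklable object"}
shm = write_to_shared_memory(payload, "demo_args")
shm.close()  # the writer keeps only the name, as write_data_for_multi_tokenizer does

restored = read_from_shared_memory("demo_args")  # any process can re-open by name
assert restored == payload

shared_memory.SharedMemory(name="demo_args").unlink()  # final cleanup by the owner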
def get_main_process_id() -> int:
"""Get the main process ID"""
return multiprocessing.current_process()._parent_pid
def write_data_for_multi_tokenizer(
port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
......@@ -612,22 +550,8 @@ def write_data_for_multi_tokenizer(
main_pid = get_main_process_id()
current_pid = os.getpid()
logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
args = (port_args, server_args, scheduler_info)
args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
args_shm.close()
# Write port_args to shared memory
port_args_shm = write_to_shared_memory(
serialize_port_args(port_args), f"port_args_{current_pid}"
)
# Write server_args to shared memory
server_args_shm = write_to_shared_memory(
serialize_server_args(server_args), f"server_args_{current_pid}"
)
# Write scheduler_info to shared memory
scheduler_info_shm = write_to_shared_memory(
serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
)
port_args_shm.close()
server_args_shm.close()
scheduler_info_shm.close()
return port_args_shm, server_args_shm, scheduler_info_shm
return args_shm
......@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.hf_transformers_utils import (
get_processor,
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
......@@ -321,8 +316,10 @@ class TokenizerManager:
# LoRA updates and inference to overlap.
self.lora_update_lock = asyncio.Lock()
# For PD disaggregation
self.init_disaggregation()
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.bootstrap_server = start_disagg_service(self.server_args)
# For load balancing
self.current_load = 0
......@@ -471,38 +468,6 @@ class TokenizerManager:
]
)
def init_disaggregation(self):
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.disaggregation_transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv boostrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=self.server_args.host,
port=self.server_args.disaggregation_bootstrap_port,
)
is_create_store = (
self.server_args.node_rank == 0
and self.server_args.disaggregation_transfer_backend == "ascend"
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......