Unverified commit e719bb0e, authored by Liangsheng Yin, committed by GitHub

[1/2] Refactor multi-tokenizer manager (#10074)

parent 06724683
@@ -704,6 +704,24 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


+def _init_tokenizer_manager(
+    server_args: ServerArgs, port_args: PortArgs
+) -> Tuple[TokenizerManager, TemplateManager]:
+    # Launch tokenizer process
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
+
+    return tokenizer_manager, template_manager
+
+
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
 ) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -816,23 +834,15 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()

+    # Init tokenizer manager first, as the bootstrap server is initialized here
     if server_args.tokenizer_worker_num > 1:
         # Launch multi-tokenizer router
         tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
-        # Initialize templates
         template_manager = None
     else:
-        # Launch tokenizer process
-        tokenizer_manager = TokenizerManager(server_args, port_args)
-        # Initialize templates
-        template_manager = TemplateManager()
-        template_manager.initialize_templates(
-            tokenizer_manager=tokenizer_manager,
-            model_path=server_args.model_path,
-            chat_template=server_args.chat_template,
-            completion_template=server_args.completion_template,
-        )
+        tokenizer_manager, template_manager = _init_tokenizer_manager(
+            server_args, port_args
+        )

     # Wait for the model to finish loading
@@ -856,5 +866,7 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]

     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]

     return tokenizer_manager, template_manager, scheduler_info
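Read together with the hunk above, both launch paths now return the same pair shape. A hedged sketch of the caller-facing contract (illustrative only, not part of the commit):

    # Illustrative: template_manager is Optional after this refactor; it is
    # None exactly on the multi-tokenizer router path.
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(server_args)
    if template_manager is None:
        assert server_args.tokenizer_worker_num > 1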
@@ -92,7 +92,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
     MultiTokenizerManager,
-    deserialize_data,
     get_main_process_id,
     read_from_shared_memory,
     write_data_for_multi_tokenizer,
@@ -136,21 +135,6 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state


-# Function to set up all middlewares for multi-tokenizer compatibility
-def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
-    """Setup all middlewares for both single and multi-process modes"""
-    worker_pid = os.getpid()
-
-    if api_key:
-        add_api_key_middleware(app, api_key)
-        logger.info(f"Worker {worker_pid} added API key middleware")
-
-    if enable_metrics:
-        add_prometheus_middleware(app)
-        enable_func_timer()
-        logger.info(f"Worker {worker_pid} added prometheus middleware")
-
-
 async def init_multi_tokenizer() -> ServerArgs:
     """Read args information from shm and init tokenizer manager for current process"""
     pid = os.getpid()
@@ -158,11 +142,15 @@ async def init_multi_tokenizer() -> ServerArgs:
     logger.info(f"current worker_id: {pid}, main processID: {main_pid}")

     # Read configuration from shared memory
-    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
-    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
-    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
-    port_args, server_args = deserialize_data(port_args_data, server_args_data)
-    scheduler_info = scheduler_info_data
+    port_args, server_args, scheduler_info = read_from_shared_memory(
+        f"multi_tokenizer_args_{main_pid}"
+    )
+    server_args: ServerArgs
+
+    # API key authentication is not supported in multi-tokenizer mode
+    assert (
+        server_args.api_key is None
+    ), "API key is not supported in multi-tokenizer mode"

     port_args.tokenizer_ipc_name = (
         f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args = getattr(fast_api_app, "server_args", None)
-    if server_args is None:
+    if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
         # Initialize multi-tokenizer support for worker processes
-        fast_api_app.server_args = await init_multi_tokenizer()
-        setup_middlewares(
-            fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
-        )
+        fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
+
+        # only metrics middleware is supported in multi-tokenizer mode
+        worker_pid = os.getpid()
+        if fast_api_app.server_args.enable_metrics:
+            add_prometheus_middleware(app)
+            enable_func_timer()
+            logger.info(f"Worker {worker_pid} added prometheus middleware")

     fast_api_app.warmup_thread = threading.Thread(
         target=_wait_and_warmup,
         args=(
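The getattr check presumably works because uvicorn with workers > 1 re-imports the app in each worker process, so an attribute set on app by the single-process launcher is absent there and the lookup falls back to False. A self-contained sketch of the pattern (illustrative, not the sglang code):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app_: FastAPI):
        if not getattr(app_, "is_single_tokenizer_mode", False):
            # Worker processes re-import this module, so the flag set by the
            # single-process launcher below does not exist here.
            print("multi-tokenizer worker startup")
        yield  # the application serves requests between startup and shutdown

    app = FastAPI(lifespan=lifespan)

    if __name__ == "__main__":
        import uvicorn

        app.is_single_tokenizer_mode = True  # visible only in this process
        uvicorn.run(app, host="127.0.0.1", port=8000)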
@@ -1187,12 +1179,10 @@ def launch_server(
     )

     if server_args.tokenizer_worker_num > 1:
-        port_args_shm, server_args_shm, scheduler_info_shm = (
-            write_data_for_multi_tokenizer(
-                port_args,
-                server_args,
-                scheduler_info,
-            )
-        )
+        multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
+            port_args,
+            server_args,
+            scheduler_info,
+        )
     else:
         # Add api key authorization
@@ -1239,6 +1229,7 @@ def launch_server(
             workers=server_args.tokenizer_worker_num,
         )
     else:
+        app.is_single_tokenizer_mode = True
         uvicorn.run(
             app,
             host=server_args.host,
@@ -1249,10 +1240,8 @@ def launch_server(
         )
     finally:
         if server_args.tokenizer_worker_num > 1:
-            port_args_shm.unlink()
-            server_args_shm.unlink()
-            scheduler_info_shm.unlink()
-            _global_state.tokenizer_manager.clear_tokenizer_mapping()
+            multi_tokenizer_args_shm.unlink()
+            _global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
         else:
             warmup_thread.join()
...
@@ -34,7 +34,7 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     MultiTokenizerRegisterReq,
 )
-from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
+from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
@@ -69,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0


-class DetokenizerManager(MultiTokenizerMixin):
+class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
     """DetokenizerManager is a process that detokenizes the token ids."""

     def __init__(
@@ -289,11 +289,11 @@ def run_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
         if server_args.tokenizer_worker_num > 1:
-            manager.multi_tokenizer_manager_event_loop()
+            manager.multi_http_worker_event_loop()
         else:
             manager.event_loop()
     except Exception:
-        manager.clear_tokenizer_mapping()
+        manager.socket_mapping.clear_all_sockets()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
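The renames in this file suggest the per-worker IPC bookkeeping moved out of the mixin into a dedicated socket_mapping object with a clear_all_sockets() teardown. A plausible shape, assumed purely for illustration (sglang's actual SocketMapping may differ):

    import zmq

    class SocketMapping:
        """Illustrative only: ZMQ sockets keyed by tokenizer-worker IPC name."""

        def __init__(self) -> None:
            self._ctx = zmq.Context.instance()
            self._mapping: dict = {}

        def get_socket(self, ipc_name: str) -> zmq.Socket:
            # Lazily connect one PUSH socket per HTTP-worker endpoint
            if ipc_name not in self._mapping:
                sock = self._ctx.socket(zmq.PUSH)
                sock.connect(ipc_name)
                self._mapping[ipc_name] = sock
            return self._mapping[ipc_name]

        def clear_all_sockets(self) -> None:
            # Close every socket so the detokenizer can exit cleanly on errors
            for sock in self._mapping.values():
                sock.close()
            self._mapping.clear()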
"""Start bootstrap/kv-store-related server"""
import os
from typing import Type
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.server_args import ServerArgs
def start_disagg_service(
server_args: ServerArgs,
):
# Start kv boostrap server on prefill
disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)
if disagg_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=server_args.host,
port=server_args.disaggregation_bootstrap_port,
)
is_create_store = (
server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
)
if is_create_store:
try:
from mf_adapter import create_config_store
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
create_config_store(ascend_url)
except Exception as e:
error_message = f"Failed create mf store, invalid ascend_url."
error_message += f" With exception {e}"
raise error_message
return bootstrap_server
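Usage sketch (illustrative): the helper returns a server only on prefill instances, so callers can treat None as "no bootstrap role":

    # Illustrative call site, mirroring TokenizerManager later in this diff.
    bootstrap_server = start_disagg_service(server_args)
    if bootstrap_server is not None:
        print("prefill instance: kv bootstrap server is listening")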
@@ -54,19 +54,14 @@ from fastapi import BackgroundTasks
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.base import BaseKVBootstrapServer
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -321,8 +316,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0
@@ -471,38 +468,6 @@ class TokenizerManager:
             ]
         )

-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
-                host=self.server_args.host,
-                port=self.server_args.disaggregation_bootstrap_port,
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
...