Unverified Commit 5f77e129 authored by ybyang, committed by GitHub

Support Multi Process Tokenizer Manager (#6555) (#8964)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Signed-off-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Huang Long <121648372+LLLL114@users.noreply.github.com>
Co-authored-by: huanglong <huanglong@linux.alibaba.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent 4750cddf
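For context: everything below is gated on the new `tokenizer_worker_num` server argument (CLI flag `--tokenizer-worker-num`, default 1). A minimal usage sketch, assuming `launch_server`'s other parameters keep their defaults; the model path is illustrative, not from this diff:

# Hypothetical usage sketch; tokenizer_worker_num > 1 switches the HTTP server
# into multi-process tokenizer mode (one uvicorn worker per tokenizer).
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model path
    tokenizer_worker_num=4,  # new argument introduced by this PR
)
launch_server(server_args)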
......@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
......@@ -814,18 +815,24 @@ def _launch_subprocesses(
),
)
detoken_proc.start()
if server_args.tokenizer_worker_num > 1:
    # Launch the multi-tokenizer router
    tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
    # Templates are initialized per worker process instead
    template_manager = None
else:
    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize templates
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
# Wait for the model to finish loading
scheduler_infos = []
......
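The hunk above selects between a `MultiTokenizerRouter` and a plain `TokenizerManager`. Pieced together from the wiring in the rest of this diff, the multi-process topology looks roughly like this (a sketch, not authoritative):

# tokenizer_worker_num = N > 1:
#
#   uvicorn worker 1..N, each with its own MultiTokenizerManager
#       -- ZMQ PUSH on tokenizer_worker_ipc_name --> MultiTokenizerRouter (main process)
#       --> Scheduler --> DetokenizerManager
#       <-- results returned on per-worker tokenizer_ipc_name sockets, which are
#           registered via MultiTokenizerRegisterReq and selected by the pid
#           prefix carried in each rid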
......@@ -23,6 +23,7 @@ import json
import logging
import multiprocessing as multiprocessing
import os
import tempfile
import threading
import time
from http import HTTPStatus
......@@ -91,11 +92,18 @@ from sglang.srt.managers.io_struct import (
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import (
MultiTokenizerManager,
deserialize_data,
get_main_process_id,
read_from_shared_memory,
write_data_for_multi_tokenizer,
)
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
add_prometheus_middleware,
......@@ -130,8 +138,79 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
# Function to set up all middlewares for multi-tokenizer compatibility
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
    """Set up all middlewares for both single- and multi-process modes."""
    worker_pid = os.getpid()

    if api_key:
        add_api_key_middleware(app, api_key)
        logger.info(f"Worker {worker_pid} added API key middleware")

    if enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()
        logger.info(f"Worker {worker_pid} added prometheus middleware")
async def init_multi_tokenizer() -> ServerArgs:
    """Read args from shared memory and init the tokenizer manager for the current worker process."""
    pid = os.getpid()
    main_pid = get_main_process_id()
    logger.info(f"current worker pid: {pid}, main process pid: {main_pid}")

    # Read configuration from shared memory
    port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
    server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
    scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
    port_args, server_args = deserialize_data(port_args_data, server_args_data)
    scheduler_info = scheduler_info_data

    # Each worker gets its own IPC endpoint to receive results
    port_args.tokenizer_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )

    # Create the tokenizer manager for this worker process
    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )
    # Register this tokenizer with the main tokenizer manager
    await tokenizer_manager.register_to_main_tokenizer_manager()

    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            template_manager=template_manager,
            scheduler_info=scheduler_info,
        )
    )
    return server_args
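`read_from_shared_memory`, `deserialize_data`, and `write_data_for_multi_tokenizer` live in the new `multi_tokenizer_mixin` module, whose diff is collapsed below. A minimal sketch of how such a handoff could work, assuming pickle over `multiprocessing.shared_memory` (function names here are illustrative, not the mixin's API):

import pickle
from multiprocessing import shared_memory

def write_to_shm_sketch(name: str, obj) -> shared_memory.SharedMemory:
    # Main process: serialize args into named shared memory before uvicorn
    # forks its workers.
    data = pickle.dumps(obj)
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(data))
    shm.buf[: len(data)] = data
    return shm  # keep the handle so launch_server can unlink() it on shutdown

def read_from_shm_sketch(name: str):
    # Worker process: look the segment up by name (derived from the main pid).
    shm = shared_memory.SharedMemory(name=name)
    obj = pickle.loads(bytes(shm.buf))
    shm.close()
    return obj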
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args = getattr(fast_api_app, "server_args", None)
if server_args is None:
# Initialize multi-tokenizer support for worker processes
fast_api_app.server_args = await init_multi_tokenizer()
setup_middlewares(
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
)
fast_api_app.warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
fast_api_app.server_args,
None, # pipe_finish_writer not needed in worker
None, # launch_callback not needed in worker
),
)
# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
......@@ -191,7 +270,15 @@ async def lifespan(fast_api_app: FastAPI):
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
try:
    yield
finally:
    if server_args.tokenizer_worker_num > 1:
        pid = os.getpid()
        logger.info(f"uvicorn worker {pid} ending...")
        warmup_thread.join()
        logger.info(f"uvicorn worker {pid} ended.")
# Fast API
......@@ -1078,9 +1165,19 @@ def launch_server(
1. The HTTP server, Engine, and TokenizerManager all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
if server_args.tokenizer_worker_num > 1:
    port_args = PortArgs.init_new(server_args)
    port_args.tokenizer_worker_ipc_name = (
        f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
    )
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args, port_args=port_args
    )
else:
    tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
        server_args=server_args,
    )
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
......@@ -1089,42 +1186,75 @@ def launch_server(
)
)
if server_args.tokenizer_worker_num > 1:
    port_args_shm, server_args_shm, scheduler_info_shm = (
        write_data_for_multi_tokenizer(
            port_args,
            server_args,
            scheduler_info,
        )
    )
else:
    # Add API key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request - we will create the thread and launch it
    # in the lifespan after all other warmups have fired.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            pipe_finish_writer,
            launch_callback,
        ),
    )
    app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
if server_args.tokenizer_worker_num > 1:
    from uvicorn.config import LOGGING_CONFIG

    LOGGING_CONFIG["loggers"]["sglang.srt.entrypoints.http_server"] = {
        "handlers": ["default"],
        "level": "INFO",
        "propagate": False,
    }
    # uvicorn needs an import string, not an app object, to spawn workers
    uvicorn.run(
        "sglang.srt.entrypoints.http_server:app",
        host=server_args.host,
        port=server_args.port,
        log_level=server_args.log_level_http or server_args.log_level,
        timeout_keep_alive=5,
        loop="uvloop",
        workers=server_args.tokenizer_worker_num,
    )
else:
    uvicorn.run(
        app,
        host=server_args.host,
        port=server_args.port,
        log_level=server_args.log_level_http or server_args.log_level,
        timeout_keep_alive=5,
        loop="uvloop",
    )
finally:
    if server_args.tokenizer_worker_num > 1:
        port_args_shm.unlink()
        server_args_shm.unlink()
        scheduler_info_shm.unlink()
        _global_state.tokenizer_manager.clear_tokenizer_mapping()
    else:
        warmup_thread.join()
def _execute_server_warmup(
......
......@@ -32,11 +32,14 @@ from sglang.srt.managers.io_struct import (
BatchStrOut,
BatchTokenIDOut,
FreezeGCReq,
MultiTokenizerRegisterReq,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
freeze_gc,
get_worker_ids_from_req_rids,
get_zmq_socket,
kill_itself_when_parent_died,
)
......@@ -67,7 +70,7 @@ class DecodeStatus:
sent_offset: int = 0
class DetokenizerManager(MultiTokenizerMixin):
"""DetokenizerManager is a process that detokenizes the token ids."""
def __init__(
......@@ -102,6 +105,7 @@ class DetokenizerManager:
(BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(MultiTokenizerRegisterReq, lambda x: x),  # pass through unchanged; routed in multi_tokenizer_manager_event_loop
(FreezeGCReq, self.handle_freeze_gc_req),
]
)
......@@ -116,6 +120,39 @@ class DetokenizerManager:
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
def multi_tokenizer_manager_event_loop(self):
    """The event loop that handles requests, for multi-tokenizer-manager mode only."""
    self.create_sockets_mapping()
    while True:
        recv_obj = self.recv_from_scheduler.recv_pyobj()
        output = self._request_dispatcher(recv_obj)
        if output is None:
            continue

        # Extract the worker ids from the rids
        if isinstance(recv_obj.rids, list):
            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
        else:
            raise RuntimeError(
                "for tokenizer_worker_num > 1, recv_obj.rids must be a list"
            )

        # Send data using the corresponding per-worker socket
        for i, worker_id in enumerate(worker_ids):
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                if self.register_tokenizer_ipc(recv_obj, worker_id):
                    logger.info(
                        f"DetokenizerManager created ZMQ socket for worker {worker_id}"
                    )
                continue
            if worker_id not in self.tokenizer_mapping:
                logger.error(
                    f"Tokenizer worker {worker_id} is not registered. "
                    f"Check if server process {worker_id} is alive."
                )
                continue
            new_output = self._handle_output_by_index(output, i)
            self.tokenizer_mapping[worker_id].send_pyobj(new_output)
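`_handle_output_by_index` is also defined in the collapsed mixin. Conceptually it has to project the i-th request out of a batched output so each worker only receives its own results; a hedged sketch of that idea, assuming batch outputs are dataclasses of parallel lists:

import dataclasses

def handle_output_by_index_sketch(output, i: int):
    # Batched outputs (e.g. BatchStrOut) carry parallel lists indexed by
    # request position; rebuild a single-request instance by slicing at i.
    kwargs = {}
    for f in dataclasses.fields(output):
        value = getattr(output, f.name)
        kwargs[f.name] = [value[i]] if isinstance(value, list) else value
    return type(output)(**kwargs)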
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
......@@ -285,8 +322,12 @@ def run_detokenizer_process(
try:
    manager = DetokenizerManager(server_args, port_args)
    if server_args.tokenizer_worker_num > 1:
        manager.multi_tokenizer_manager_event_loop()
    else:
        manager.event_loop()
except Exception:
    manager.clear_tokenizer_mapping()
    traceback = get_exception_traceback()
    logger.error(f"DetokenizerManager hit an exception: {traceback}")
    parent_process.send_signal(signal.SIGQUIT)
......@@ -983,6 +983,11 @@ class AbortReq:
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
# Used in MultiTokenizerManager mode
rids: Optional[Union[List[str], str]] = None

def __post_init__(self):
    # Mirror rid into rids so abort results can be routed like other outputs
    self.rids = self.rid
@dataclass
......@@ -1183,6 +1188,18 @@ class LoRAUpdateResult:
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@dataclass
class MultiTokenizerRegisterReq:
rids: Optional[Union[List[str], str]] = None
ipc_name: Optional[str] = None
@dataclass
class MultiTokenizerWarpper:
    """Wraps a request together with the originating tokenizer worker's id."""

    worker_id: int
    obj: Optional[Any] = None
class BlockReqType(Enum):
BLOCK = 1
UNBLOCK = 2
......
This diff is collapsed.
......@@ -84,6 +84,8 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
MultiTokenizerRegisterReq,
MultiTokenizerWarpper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -257,7 +259,6 @@ class Scheduler(
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
......@@ -540,6 +541,7 @@ class Scheduler(
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
]
)
......@@ -1101,6 +1103,17 @@ class Scheduler(
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
# If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
if isinstance(recv_req, MultiTokenizerWarpper):
worker_id = recv_req.worker_id
recv_req = recv_req.obj
output = self._request_dispatcher(recv_req)
if output is not None:
output = MultiTokenizerWarpper(worker_id, output)
self.send_to_tokenizer.send_pyobj(output)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
......@@ -2474,6 +2487,10 @@ class Scheduler(
result = self.tp_worker.unload_lora_adapter(recv_req)
return result
def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
    # Forward the registration to the detokenizer; the returned request is
    # echoed back to the tokenizer side by the dispatcher loop
    self.send_to_detokenizer.send_pyobj(recv_req)
    return recv_req
def slow_down(self, recv_req: SlowDownReqInput):
t = recv_req.forward_sleep_time
if t is not None and t <= 0:
......
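Together with the detokenizer changes above, registration is a one-way relay: worker tokenizer -> scheduler -> detokenizer, with the reply echoed back through `send_to_tokenizer`. A sketch of the message a worker would send, using names from this diff (the exact rid format is an assumption):

import os

from sglang.srt.managers.io_struct import MultiTokenizerRegisterReq

ipc_name = "ipc:///tmp/example_worker_socket"  # per-worker endpoint (illustrative)
req = MultiTokenizerRegisterReq(
    rids=[f"{os.getpid()}_register"],  # the rid prefix must be the worker's pid
    ipc_name=ipc_name,
)
# The DetokenizerManager parses the pid prefix, opens a PUSH socket to
# ipc_name, and stores it in tokenizer_mapping[worker_id].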
......@@ -94,6 +94,7 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWarpper,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -131,6 +132,7 @@ from sglang.srt.utils import (
dataclass_to_string_truncated,
freeze_gc,
get_bool_env_var,
get_origin_rid,
get_zmq_socket,
kill_process_tree,
)
......@@ -266,9 +268,15 @@ class TokenizerManager:
self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name, True
)
if self.server_args.tokenizer_worker_num > 1:
    # Use tokenizer_worker_ipc_name in multi-tokenizer mode
    self.send_to_scheduler = get_zmq_socket(
        context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
    )
else:
    self.send_to_scheduler = get_zmq_socket(
        context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
    )
# Request states
self.no_create_loop = False
......@@ -312,35 +320,7 @@ class TokenizerManager:
self.lora_update_lock = asyncio.Lock()
# For PD disaggregation
self.init_disaggregation()
# For load balancing
self.current_load = 0
......@@ -488,6 +468,37 @@ class TokenizerManager:
]
)
def init_disaggregation(self):
    self.disaggregation_mode = DisaggregationMode(
        self.server_args.disaggregation_mode
    )
    self.disaggregation_transfer_backend = TransferBackend(
        self.server_args.disaggregation_transfer_backend
    )

    # Start the KV bootstrap server on prefill
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        # Only start the bootstrap server on the prefill tokenizer manager
        kv_bootstrap_server_class = get_kv_class(
            self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        self.bootstrap_server = kv_bootstrap_server_class(
            self.server_args.disaggregation_bootstrap_port
        )
        is_create_store = (
            self.server_args.node_rank == 0
            and self.server_args.disaggregation_transfer_backend == "ascend"
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                error_message = (
                    "Failed to create mf store, invalid ascend_url. "
                    f"With exception {e}"
                )
                raise RuntimeError(error_message) from e
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -497,6 +508,15 @@ class TokenizerManager:
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if self.server_args.tokenizer_worker_num > 1:
# Modify rid, add worker_id
if isinstance(obj.rid, list):
# If it's a list, prefix each element with the worker_id
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
else:
# If it's a single value, add worker_id prefix
obj.rid = f"{self.worker_id}_{obj.rid}"
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
......@@ -1096,6 +1116,8 @@ class TokenizerManager:
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]:
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
......@@ -1315,6 +1337,8 @@ class TokenizerManager:
elif obj.session_id in self.session_futures:
return None
if self.server_args.tokenizer_worker_num > 1:
obj = MultiTokenizerWarpper(self.worker_id, obj)
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[obj.session_id] = asyncio.Future()
......@@ -1590,7 +1614,6 @@ class TokenizerManager:
async def handle_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
......@@ -1610,9 +1633,12 @@ class TokenizerManager:
)
continue
origin_rid = rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(rid)
# Build meta_info and return value
meta_info = {
"id": rid,
"id": origin_rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
"weight_version": self.server_args.weight_version,
......@@ -1918,6 +1944,9 @@ class TokenizerManager:
if is_health_check_generate_req(recv_obj):
return
state = self.rid_to_state[recv_obj.rid]
origin_rid = recv_obj.rid
if self.server_args.tokenizer_worker_num > 1:
origin_rid = get_origin_rid(origin_rid)
state.finished = True
if recv_obj.finished_reason:
out = {
......@@ -1930,7 +1959,7 @@ class TokenizerManager:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
"id": origin_rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
......@@ -2116,6 +2145,8 @@ T = TypeVar("T")
class _Communicator(Generic[T]):
    """Note: The communicator runs at most one in-flight request at any time."""

    enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
......@@ -2132,6 +2163,8 @@ class _Communicator(Generic[T]):
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
......
......@@ -128,6 +128,7 @@ class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
tokenizer_worker_num: int = 1
skip_tokenizer_init: bool = False
load_format: str = "auto"
model_loader_extra_config: str = "{}"
......@@ -827,6 +828,12 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument(
"--tokenizer-worker-num",
type=int,
default=ServerArgs.tokenizer_worker_num,
help="The worker num of the tokenizer manager.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
......@@ -2176,6 +2183,9 @@ class ServerArgs:
self.chunked_prefill_size % self.page_size == 0
), "chunked_prefill_size must be divisible by page_size"
# Check multi tokenizer
assert self.tokenizer_worker_num > 0, "tokenizer_worker_num must be >= 1"
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......@@ -2419,6 +2429,9 @@ class PortArgs:
# The ipc filename for Scheduler to send metrics
metrics_ipc_name: str
# The ipc filename for the main tokenizer router and the per-worker tokenizers
tokenizer_worker_ipc_name: Optional[str]
@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
if server_args.nccl_port is None:
......@@ -2442,6 +2455,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
tokenizer_worker_ipc_name=None,
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
......@@ -2475,6 +2489,7 @@ class PortArgs:
nccl_port=nccl_port,
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
tokenizer_worker_ipc_name=None,
)
......
......@@ -2787,6 +2787,20 @@ def lru_cache_frozenset(maxsize=128):
return decorator
def get_worker_ids_from_req_rids(rids):
if isinstance(rids, list):
worker_ids = [int(rid.split("_")[0]) for rid in rids]
elif isinstance(rids, str):
worker_ids = [int(rids.split("_")[0])]
else:
worker_ids = []
return worker_ids
def get_origin_rid(rid):
return rid.split("_", 1)[1] if "_" in rid else rid
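A worked example of the rid convention these two helpers implement (worker ids are process pids, per the TokenizerManager changes above):

# Workers prefix every request id with their pid ("<pid>_<rid>"); the
# detokenizer routes on the prefix and the tokenizer strips it before
# building meta_info for the client.
rids = ["12345_req-0", "67890_req-1"]  # illustrative pids and rids
assert get_worker_ids_from_req_rids(rids) == [12345, 67890]
assert get_origin_rid("12345_req-0") == "req-0"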
def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path(
target_module, target_function, False
......
......@@ -85,6 +85,7 @@ suites = {
TestFile("test_mla_int8_deepseek_v3.py", 429),
TestFile("test_mla_flashinfer.py", 302),
TestFile("test_mla_fp8.py", 93),
TestFile("test_multi_tokenizer.py", 230),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 234),
TestFile("test_original_logprobs.py", 200),
......
import unittest
from types import SimpleNamespace
import sglang.srt.managers.io_struct as io_struct
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
auto_config_device,
get_benchmark_args,
is_in_ci,
popen_launch_server,
run_benchmark,
write_github_step_summary,
)
class TestMultiTokenizer(CustomTestCase):
# from test_hicache.py
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tokenizer-worker-num",
8,
"--mem-fraction-static",
0.7,
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65)
def test_multi_tokenizer_ttft(self):
# from test_bench_serving.py run_bench_serving
args = get_benchmark_args(
base_url=self.base_url,
dataset_name="random",
dataset_path="",
tokenizer=None,
num_prompts=100,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
request_rate=1,
disable_stream=False,
disable_ignore_eos=False,
seed=0,
device=auto_config_device(),
lora_name=None,
)
res = run_benchmark(args)
if is_in_ci():
write_github_step_summary(
f"### test_multi_tokenizer_ttft\n"
f"median_e2e_latency_ms: {res['median_e2e_latency_ms']:.2f} ms\n"
)
self.assertLess(res["median_e2e_latency_ms"], 11000)
self.assertLess(res["median_ttft_ms"], 86)
self.assertLess(res["median_itl_ms"], 10)
if __name__ == "__main__":
unittest.main()