"docs/source/vscode:/vscode.git/clone" did not exist on "9430bec6262830afa31b38aa25c07679e733d9c9"
Unverified Commit fe97a2d4 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify tokenizer manager (#2254)

parent 8b48496a
......@@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
from sglang.api import Engine
from sglang.bench_serving import (
get_dataset,
get_tokenizer,
sample_random_requests,
set_ulimit,
)
from sglang.srt.server import Runtime, start_profile, stop_profile
from sglang.srt.server import Engine, Runtime
from sglang.srt.server_args import ServerArgs
......@@ -204,12 +203,12 @@ def throughput_test_once(
st = time.perf_counter()
if profile:
start_profile()
backend.start_profile()
gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
if profile:
stop_profile()
backend.stop_profile()
monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
latency = time.perf_counter() - st
......
......@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
}
......
......@@ -376,16 +376,6 @@ class ProfileReq(Enum):
STOP_PROFILE = 2
@dataclass
class GetMemPoolSizeReq:
pass
@dataclass
class GetMemPoolSizeReqOutput:
size: int
@dataclass
class OpenSessionReqInput:
capacity_of_str_len: int
......
......@@ -38,8 +38,6 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -521,10 +519,6 @@ class Scheduler:
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
elif isinstance(recv_req, GetMemPoolSizeReq):
self.send_to_tokenizer.send_pyobj(
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
)
else:
raise ValueError(f"Invalid request: {recv_req}")
......
......@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -218,7 +216,7 @@ class TokenizerManager:
input_ids = obj.input_ids
if self.is_generation:
image_inputs = await self.image_processor.process_images_async(
image_inputs: Dict = await self.image_processor.process_images_async(
obj.image_data, input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
......@@ -406,25 +404,6 @@ class TokenizerManager:
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def get_memory_pool_size(self):
if self.to_create_loop:
self.create_handle_loop()
req = GetMemPoolSizeReq()
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
if self.server_args.dp_size == 1:
res = await self.mem_pool_size
return res.size
else: # self.server_args.dp_size > 1
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]
return ret
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
):
......@@ -552,15 +531,6 @@ class TokenizerManager:
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
else: # self.sever_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj)
# set future if the all results are received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
......
......@@ -24,7 +24,6 @@ import logging
import multiprocessing as mp
import os
import signal
import sys
import threading
import time
from http import HTTPStatus
......@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Fast API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
......@@ -105,7 +104,7 @@ app.add_middleware(
)
tokenizer_manager: TokenizerManager = None
_max_total_num_tokens = None
scheduler_info: Dict = None
##### Native API endpoints #####
......@@ -171,16 +170,6 @@ async def flush_cache():
)
def start_profile():
"""Start profiling."""
tokenizer_manager.start_profile()
def stop_profile():
"""Stop profiling."""
tokenizer_manager.stop_profile()
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile_async():
......@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
......@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
......@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
)
app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
......@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
)
app.post("/classify")(classify_request)
app.put("/classify")(classify_request)
##### OpenAI-compatible API endpoints #####
......@@ -392,11 +372,11 @@ def launch_engine(
server_args: ServerArgs,
):
"""
Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
global tokenizer_manager
global _max_total_num_tokens
global scheduler_info
# Configure global environment
configure_logger(server_args)
......@@ -462,8 +442,8 @@ def launch_engine(
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
# Wait for model to finish loading & get max token nums
scheduler_info = []
# Wait for model to finish loading
scheduler_infos = []
for i in range(len(scheduler_pipe_readers)):
data = scheduler_pipe_readers[i].recv()
......@@ -471,10 +451,10 @@ def launch_engine(
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
scheduler_info.append(data)
scheduler_infos.append(data)
# Assume all schedulers have same max_total_num_tokens
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
scheduler_info = scheduler_infos[0]
def launch_server(
......@@ -488,12 +468,12 @@ def launch_server(
1. HTTP server: A FastAPI server that routes requests to the engine.
2. SRT engine:
1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note:
1. The HTTP server and Tokenizer Manager both run in the main process.
1. The HTTP server and TokenizerManager both run in the main process.
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
"""
launch_engine(server_args=server_args)
......@@ -502,7 +482,7 @@ def launch_server(
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# add prometheus middleware
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
......@@ -514,7 +494,7 @@ def launch_server(
t.start()
try:
# Listen for HTTP requests
# Update logging configs
LOGGING_CONFIG["formatters"]["default"][
"fmt"
] = "[%(asctime)s] %(levelprefix)s %(message)s"
......@@ -523,6 +503,8 @@ def launch_server(
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
# Listen for HTTP requests
uvicorn.run(
app,
host=server_args.host,
......@@ -538,8 +520,7 @@ def launch_server(
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
**scheduler_info,
"version": __version__,
}
......@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
kill_process_tree(os.getpid())
return
# Debug print
# logger.info(f"{res.json()=}")
logger.info("The server is fired up and ready to roll!")
......@@ -821,18 +803,11 @@ class Engine:
launching the HTTP server adds unnecessary complexity or overhead,
"""
def __init__(self, *args, **kwargs):
def __init__(self, log_level: str = "error", *args, **kwargs):
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
atexit.register(self.shutdown)
# runtime server default log level is log
# offline engine works in scripts, so we set it to error
if "log_level" not in kwargs:
kwargs["log_level"] = "error"
server_args = ServerArgs(*args, **kwargs)
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)
def generate(
......@@ -955,5 +930,11 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))
def start_profile(self):
tokenizer_manager.start_profile()
def stop_profile(self):
tokenizer_manager.stop_profile()
async def get_server_info(self):
return await _get_server_info()
......@@ -220,9 +220,6 @@ class TestSRTEndpoint(unittest.TestCase):
max_total_num_tokens = response_json["max_total_num_tokens"]
self.assertIsInstance(max_total_num_tokens, int)
memory_pool_size = response_json["memory_pool_size"]
self.assertIsInstance(memory_pool_size, int)
attention_backend = response_json["attention_backend"]
self.assertIsInstance(attention_backend, str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment