Unverified commit fe97a2d4, authored by Lianmin Zheng, committed by GitHub

Simplify tokenizer manager (#2254)

parent 8b48496a
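For context, a minimal usage sketch of the refactored surface implied by the hunks below (the model path and prompt are placeholders, not part of this change): the offline `Engine` now takes `log_level` as an explicit keyword defaulting to `"error"`, and profiling is exposed as `Engine.start_profile()` / `Engine.stop_profile()`, which is what the offline throughput benchmark now calls on `backend` instead of the removed module-level helpers.

```python
# Hypothetical sketch based on this diff; the model path and prompt are placeholders.
from sglang.srt.server import Engine

# log_level is now an explicit keyword with a default of "error" (see the Engine.__init__ hunk).
engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", log_level="error")

# Profiling moved from the removed module-level start_profile()/stop_profile() helpers
# to methods on Engine; traces land in SGLANG_TORCH_PROFILER_DIR per the benchmark hunk.
engine.start_profile()
out = engine.generate(prompt="Hello", sampling_params={"max_new_tokens": 8})
engine.stop_profile()

engine.shutdown()
```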
@@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
-from sglang.api import Engine
 from sglang.bench_serving import (
     get_dataset,
     get_tokenizer,
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Runtime, start_profile, stop_profile
+from sglang.srt.server import Engine, Runtime
 from sglang.srt.server_args import ServerArgs
@@ -204,12 +203,12 @@ def throughput_test_once(
     st = time.perf_counter()
     if profile:
-        start_profile()
+        backend.start_profile()
     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
     if profile:
-        stop_profile()
+        backend.stop_profile()
         monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
     latency = time.perf_counter() - st
......
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
         }
......
@@ -376,16 +376,6 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
-@dataclass
-class GetMemPoolSizeReq:
-    pass
-@dataclass
-class GetMemPoolSizeReqOutput:
-    size: int
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
......
@@ -38,8 +38,6 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -521,10 +519,6 @@ class Scheduler:
                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
-           elif isinstance(recv_req, GetMemPoolSizeReq):
-               self.send_to_tokenizer.send_pyobj(
-                   GetMemPoolSizeReqOutput(self.max_total_num_tokens)
-               )
            else:
                raise ValueError(f"Invalid request: {recv_req}")
......
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -218,7 +216,7 @@ class TokenizerManager:
             input_ids = obj.input_ids
         if self.is_generation:
-            image_inputs = await self.image_processor.process_images_async(
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -406,25 +404,6 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
-    async def get_memory_pool_size(self):
-        if self.to_create_loop:
-            self.create_handle_loop()
-        req = GetMemPoolSizeReq()
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else:  # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -552,15 +531,6 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.mem_pool_size.set_result(recv_obj)
-                else:  # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
......
@@ -24,7 +24,6 @@ import logging
 import multiprocessing as mp
 import os
 import signal
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-# Fast API
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -105,7 +104,7 @@ app.add_middleware(
 )
 tokenizer_manager: TokenizerManager = None
-_max_total_num_tokens = None
+scheduler_info: Dict = None
 ##### Native API endpoints #####
@@ -171,16 +170,6 @@ async def flush_cache():
     )
-def start_profile():
-    """Start profiling."""
-    tokenizer_manager.start_profile()
-def stop_profile():
-    """Stop profiling."""
-    tokenizer_manager.stop_profile()
 @app.get("/start_profile")
 @app.post("/start_profile")
 async def start_profile_async():
@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
     )
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
 @time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     )
-# fastapi implicitly converts json in the request to obj (dataclass)
-app.post("/generate")(generate_request)
-app.put("/generate")(generate_request)
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
     )
-app.post("/encode")(encode_request)
-app.put("/encode")(encode_request)
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
     )
-app.post("/classify")(classify_request)
-app.put("/classify")(classify_request)
 ##### OpenAI-compatible API endpoints #####
@@ -392,11 +372,11 @@ def launch_engine(
     server_args: ServerArgs,
 ):
     """
-    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
     global tokenizer_manager
-    global _max_total_num_tokens
+    global scheduler_info
     # Configure global environment
     configure_logger(server_args)
@@ -462,8 +442,8 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
-    # Wait for model to finish loading & get max token nums
-    scheduler_info = []
+    # Wait for model to finish loading
+    scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
         data = scheduler_pipe_readers[i].recv()
@@ -471,10 +451,10 @@ def launch_engine(
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-        scheduler_info.append(data)
+        scheduler_infos.append(data)
     # Assume all schedulers have same max_total_num_tokens
-    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+    scheduler_info = scheduler_infos[0]
 def launch_server(
@@ -488,12 +468,12 @@ def launch_server(
     1. HTTP server: A FastAPI server that routes requests to the engine.
     2. SRT engine:
-        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
-        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
     Note:
-    1. The HTTP server and Tokenizer Manager both run in the main process.
+    1. The HTTP server and TokenizerManager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
     launch_engine(server_args=server_args)
@@ -502,7 +482,7 @@ def launch_server(
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)
-    # add prometheus middleware
+    # Add prometheus middleware
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
         enable_func_timer()
@@ -514,7 +494,7 @@ def launch_server(
     t.start()
     try:
-        # Listen for HTTP requests
+        # Update logging configs
        LOGGING_CONFIG["formatters"]["default"][
             "fmt"
         ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -523,6 +503,8 @@ def launch_server(
             "fmt"
         ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
         LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -538,8 +520,7 @@ def launch_server(
 async def _get_server_info():
     return {
         **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
-        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+        **scheduler_info,
         "version": __version__,
     }
@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         kill_process_tree(os.getpid())
         return
+    # Debug print
     # logger.info(f"{res.json()=}")
     logger.info("The server is fired up and ready to roll!")
@@ -821,18 +803,11 @@ class Engine:
     launching the HTTP server adds unnecessary complexity or overhead,
     """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, log_level: str = "error", *args, **kwargs):
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
-        # runtime server default log level is log
-        # offline engine works in scripts, so we set it to error
-        if "log_level" not in kwargs:
-            kwargs["log_level"] = "error"
-        server_args = ServerArgs(*args, **kwargs)
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
         launch_engine(server_args=server_args)
     def generate(
@@ -955,5 +930,11 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))
+    def start_profile(self):
+        tokenizer_manager.start_profile()
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
     async def get_server_info(self):
         return await _get_server_info()
@@ -220,9 +220,6 @@ class TestSRTEndpoint(unittest.TestCase):
         max_total_num_tokens = response_json["max_total_num_tokens"]
         self.assertIsInstance(max_total_num_tokens, int)
-        memory_pool_size = response_json["memory_pool_size"]
-        self.assertIsInstance(memory_pool_size, int)
         attention_backend = response_json["attention_backend"]
         self.assertIsInstance(attention_backend, str)
......
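A hedged check of the updated `/get_server_info` contract implied by the server and test hunks above (the URL is a placeholder): the dedicated `memory_pool_size` field and its `GetMemPoolSizeReq` round trip are gone, and the scheduler's startup info (e.g. `max_total_num_tokens`) is merged directly into the response next to the server args and version.

```python
# Hypothetical probe of /get_server_info after this change; the URL is a placeholder.
import requests

info = requests.get("http://127.0.0.1:30000/get_server_info").json()

# Keys now come from scheduler_info and the server args (see the _get_server_info hunk).
assert isinstance(info["max_total_num_tokens"], int)
assert isinstance(info["attention_backend"], str)
assert "version" in info

# The old dedicated field was removed together with GetMemPoolSizeReq/GetMemPoolSizeReqOutput.
assert "memory_pool_size" not in info
```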