Unverified commit 90227800, authored by Lianmin Zheng and committed by GitHub

[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)

parent 30b4f771
......@@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high thro
When the server is running at full load, look for the following in the log:
```[gpu=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```
```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```
### Tune Your Request Submission Speed
`#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.
......
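For quick checks while tuning, a small parser for the decode-batch line shown above can be handy. This is a hedged sketch, not part of this commit; the regex simply mirrors the log format printed by the server, and the function and field names are illustrative:

```python
# Hedged sketch: parse the decode-batch log line to track token usage and queue depth.
import re

LOG_PATTERN = re.compile(
    r"Decode batch\. #running-req: (?P<running>\d+), #token: (?P<tokens>\d+), "
    r"token usage: (?P<usage>[\d.]+), gen throughput \(token/s\): (?P<tput>[\d.]+), "
    r"#queue-req: (?P<queue>\d+)"
)


def parse_decode_log(line: str):
    """Return the decode-batch metrics from one log line, or None if it does not match."""
    m = LOG_PATTERN.search(line)
    if m is None:
        return None
    return {
        "running_reqs": int(m["running"]),
        "tokens": int(m["tokens"]),
        "token_usage": float(m["usage"]),
        "throughput": float(m["tput"]),
        "queue_reqs": int(m["queue"]),
    }


line = ("Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, "
        "gen throughput (token/s): 4594.01, #queue-req: 417")
print(parse_decode_log(line))
```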
......@@ -142,17 +142,6 @@ def get_tokenizer(
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if (
"llama" in tokenizer_name.lower()
and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER
):
warnings.warn(
"For some LLaMA V1 models, initializing the fast tokenizer may "
"take a long time. To reduce the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer."
)
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
......
......@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -193,10 +193,7 @@ def start_controller_process(
):
"""Start a controller process."""
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
configure_logger(server_args)
try:
controller = ControllerMulti(server_args, port_args, model_overide_args)
......
......@@ -27,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.srt.utils import configure_logger, kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -52,7 +52,7 @@ class ControllerSingle:
self.dp_worker_id = dp_worker_id
self.mp_queue = mp_queue
# Init communication
# Init inter-process communication
context = zmq.Context(2)
if not self.is_dp_worker:
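The comment above marks the ZeroMQ sockets that carry requests between the tokenizer, controller, and detokenizer processes. Below is a minimal, hedged sketch of the underlying PULL/PUSH pattern; the port number and payload are made up for illustration, and in the real code the two ends live in different processes:

```python
# Hedged sketch of the PULL/PUSH inter-process pattern used by the managers.
import zmq

context = zmq.Context(2)  # 2 I/O threads, as in the snippets above

# Receiving side: bind a PULL socket on a known port (hypothetical port here).
receiver = context.socket(zmq.PULL)
receiver.bind("tcp://127.0.0.1:10010")

# Sending side (normally another process): connect a PUSH socket to the same port.
sender = context.socket(zmq.PUSH)
sender.connect("tcp://127.0.0.1:10010")

sender.send_pyobj({"rid": "req-0", "text": "hello"})  # pickled Python object
print(receiver.recv_pyobj())                          # {'rid': 'req-0', 'text': 'hello'}
```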
......@@ -133,11 +133,11 @@ def start_controller_process(
queue: multiprocessing.connection.Connection = None,
):
"""Start a controller process."""
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
if is_data_parallel_worker:
logger_prefix = f" DP{dp_worker_id} TP0"
else:
logger_prefix = " TP0"
configure_logger(server_args, prefix=logger_prefix)
if not is_data_parallel_worker:
tp_size_local = server_args.tp_size // server_args.nnodes
......
......@@ -56,6 +56,7 @@ class DetokenizerManager:
server_args: ServerArgs,
port_args: PortArgs,
):
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
......@@ -75,10 +76,13 @@ class DetokenizerManager:
self.decode_status = {}
async def handle_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
recv_obj = await self.recv_from_router.recv_pyobj()
if isinstance(recv_obj, BatchEmbeddingOut):
# If it is an embedding model, no detokenization is needed.
self.send_to_tokenizer.send_pyobj(
BatchEmbeddingOut(
rids=recv_obj.rids,
......@@ -88,19 +92,18 @@ class DetokenizerManager:
)
)
continue
if isinstance(recv_obj, UpdateWeightReqOutput):
elif isinstance(recv_obj, UpdateWeightReqOutput):
# If it is a weight update request, no detokenization is needed.
self.send_to_tokenizer.send_pyobj(recv_obj)
continue
elif self.tokenizer is None:
# If the tokenizer is skipped, no detokenization is needed.
self.send_to_tokenizer.send_pyobj(recv_obj)
continue
assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)
if self.tokenizer is None:
# Send BatchTokenIDOut if no tokenizer init'ed.
self.send_to_tokenizer.send_pyobj(recv_obj)
continue
# Initialize decode status
read_ids, surr_ids = [], []
for i in range(bs):
......@@ -134,6 +137,7 @@ class DetokenizerManager:
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
# Incremental decoding
output_strs = []
for i in range(bs):
s = self.decode_status[recv_obj.rids[i]]
......
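The decode-status bookkeeping above drives incremental detokenization: each request keeps offsets into its token ids so that only newly finalized text is sent out. A hedged sketch of that idea follows, with illustrative names rather than the manager's real fields:

```python
# Hedged sketch of incremental detokenization: decode the full id list each step,
# emit only the suffix that was not sent before, and hold text back while the
# last token still decodes to an incomplete (replacement) character.
from dataclasses import dataclass, field
from typing import List


@dataclass
class DecodeStatus:
    decoded_text: str = ""                      # text already emitted to the client
    read_offset: int = 0                        # ids already reflected in decoded_text
    token_ids: List[int] = field(default_factory=list)


def incremental_decode(tokenizer, status: DecodeStatus, new_token_ids: List[int]) -> str:
    """Return only the newly finalized text for this request."""
    status.token_ids.extend(new_token_ids)
    old_text = tokenizer.decode(status.token_ids[: status.read_offset])
    new_text = tokenizer.decode(status.token_ids)
    if new_text.endswith("\ufffd"):
        return ""                               # wait for more tokens
    delta = new_text[len(old_text):]
    status.decoded_text += delta
    status.read_offset = len(status.token_ids)
    return delta


if __name__ == "__main__":
    class ByteTokenizer:
        """Toy tokenizer whose ids are UTF-8 byte values (illustration only)."""
        def decode(self, ids):
            return bytes(ids).decode("utf-8", errors="replace")

    status, tok = DecodeStatus(), ByteTokenizer()
    ids = list("héllo".encode("utf-8"))
    print(repr(incremental_decode(tok, status, ids[:2])))   # '' (incomplete UTF-8 held back)
    print(repr(incremental_decode(tok, status, ids[2:])))   # 'héllo'
```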
......@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import transformers
......@@ -80,6 +80,7 @@ class TokenizerManager:
):
self.server_args = server_args
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
......@@ -87,6 +88,7 @@ class TokenizerManager:
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.hf_config = get_config(
......@@ -104,6 +106,7 @@ class TokenizerManager:
else:
self.context_len = get_context_length(self.hf_config)
# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
......@@ -127,6 +130,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
)
# Store states
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
......@@ -134,63 +138,6 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
async def get_pixel_values(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None
return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
get_pixel_values,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return get_pixel_values(
image_data, aspect_ratio, grid_pinpoints, self.processor
)
async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
):
......@@ -198,7 +145,7 @@ class TokenizerManager:
self.create_handle_loop()
while self.model_update_lock.locked():
await asyncio.sleep(0)
await asyncio.sleep(0.001)
obj.post_init()
is_single = obj.is_single
......@@ -214,8 +161,8 @@ class TokenizerManager:
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
index=None,
is_cache_for_prefill=False,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
):
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
......@@ -235,7 +182,7 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hash, image_size = await self.get_pixel_values(
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data
)
return_logprob = (
......@@ -345,7 +292,7 @@ class TokenizerManager:
parallel_sample_num = obj.parallel_sample_num
if parallel_sample_num != 1:
# Send prefill requests to cache the common input
# Send prefill requests to cache the common prefix
parallel_sample_num += 1
input_id_result = [] if obj.input_ids is None else None
for i in range(batch_size):
......@@ -436,7 +383,6 @@ class TokenizerManager:
)
# Then process the responses based on streaming option
is_stream = hasattr(obj, "stream") and obj.stream
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
......@@ -482,9 +428,9 @@ class TokenizerManager:
async def _get_pixel_values(self, image_data):
if isinstance(image_data, list) and len(image_data) > 0:
return await self.get_pixel_values(image_data[0])
return await self._get_pixel_values_internal(image_data[0])
elif isinstance(image_data, str):
return await self.get_pixel_values(image_data)
return await self._get_pixel_values_internal(image_data)
else:
return None, None, None
......@@ -563,6 +509,13 @@ class TokenizerManager:
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
async def update_weights(self, obj: UpdateWeightReqInput, request):
if self.to_create_loop:
self.create_handle_loop()
......@@ -587,13 +540,6 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
......@@ -617,6 +563,8 @@ class TokenizerManager:
loop.create_task(self.handle_loop())
async def handle_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
......@@ -713,11 +661,69 @@ class TokenizerManager:
)
return top_logprobs
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None
return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
_process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return _process_single_image_task(
image_data, aspect_ratio, grid_pinpoints, self.processor
)
global global_processor
def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
......@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
)
def get_pixel_values(
def _process_single_image_task(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
try:
......@@ -759,4 +765,4 @@ def get_pixel_values(
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
......@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
configure_logger,
is_multimodal_model,
set_random_seed,
suppress_other_loggers,
......@@ -145,7 +146,6 @@ class ModelTpServer:
# Print info
logger.info(
f"[gpu={self.gpu_id}] "
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
......@@ -284,7 +284,7 @@ class ModelTpServer:
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"[gpu={self.gpu_id}] Decode batch. "
f"Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
......@@ -443,7 +443,7 @@ class ModelTpServer:
if num_mixed_running > 0:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch"
f"Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
......@@ -453,7 +453,7 @@ class ModelTpServer:
)
else:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
......@@ -631,7 +631,7 @@ class ModelTpServer:
self.new_token_ratio = new_token_ratio
logger.info(
"decode out of memory happened, "
"Decode out of memory happened. "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
......@@ -848,7 +848,9 @@ def run_tp_server(
nccl_port: int,
model_overide_args: dict,
):
"""Run a tensor parallel server."""
"""Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}")
try:
model_server = ModelTpServer(
gpu_id,
......
......@@ -109,7 +109,7 @@ class ModelRunner:
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
logger.info("Init nccl begin.")
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
......@@ -152,8 +152,7 @@ class ModelRunner:
def load_model(self):
logger.info(
f"[gpu={self.gpu_id}] Load weight begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
......@@ -208,7 +207,7 @@ class ModelRunner:
)
logger.info(
f"[gpu={self.gpu_id}] Load weight end. "
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
......@@ -224,7 +223,7 @@ class ModelRunner:
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info(
f"[gpu={self.gpu_id}] Update weights begin. "
f"Update weights begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
......@@ -298,7 +297,7 @@ class ModelRunner:
self.load_config = load_config
self.model_config.path = model_path
logger.info(f"[gpu={self.gpu_id}] Update weights end.")
logger.info("Update weights end.")
return True, "Succeeded to update model weights"
def profile_max_num_token(self, total_gpu_memory: int):
......@@ -387,7 +386,7 @@ class ModelRunner:
layer_num=self.model_config.num_hidden_layers,
)
logger.info(
f"[gpu={self.gpu_id}] Memory pool end. "
f"Memory pool end. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
......@@ -473,9 +472,7 @@ class ModelRunner:
self.cuda_graph_runner = None
return
logger.info(
f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
logger.info("Capture cuda graph begin. This can take up to several minutes.")
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
......
......@@ -123,7 +123,7 @@ def create_streaming_error_response(
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
global chat_template_name
print(f"Use chat template: {chat_template_arg}")
logger.info(f"Use chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
......@@ -355,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
}
except Exception as e:
print("error in SGLang:", e)
logger.error("error in SGLang:", e)
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
......
......@@ -74,6 +74,7 @@ from sglang.srt.utils import (
add_api_key_middleware,
allocate_init_ports,
assert_pkg_version,
configure_logger,
enable_show_time_cost,
kill_child_process,
maybe_set_triton_cache_manager,
......@@ -270,15 +271,12 @@ def launch_server(
"""Launch an HTTP server."""
global tokenizer_manager
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
configure_logger(server_args)
server_args.check_server_args()
_set_envs_and_config(server_args)
# Allocate ports
# Allocate ports for inter-process communication
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
......
......@@ -418,7 +418,7 @@ class ServerArgs:
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a chunked batch.",
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
)
parser.add_argument(
"--enable-torch-compile",
......
......@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
def add_api_key_middleware(app, api_key):
def add_api_key_middleware(app, api_key: str):
@app.middleware("http")
async def authentication(request, call_next):
if request.method == "OPTIONS":
......@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
return await call_next(request)
def prepare_model(model_path):
def prepare_model(model_path: str):
if "SGLANG_USE_MODELSCOPE" in os.environ:
if not os.path.exists(model_path):
from modelscope import snapshot_download
......@@ -713,7 +713,7 @@ def prepare_model(model_path):
return model_path
def prepare_tokenizer(tokenizer_path):
def prepare_tokenizer(tokenizer_path: str):
if "SGLANG_USE_MODELSCOPE" in os.environ:
if not os.path.exists(tokenizer_path):
from modelscope import snapshot_download
......@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
)
return tokenizer_path
def configure_logger(server_args, prefix: str = ""):
format = f"[%(asctime)s{prefix}] %(message)s"
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format=format,
datefmt="%H:%M:%S",
force=True,
)
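A hedged usage sketch of the new helper: any object exposing `log_level` can stand in for `ServerArgs`, and the prefix matches what the controller and TP workers pass in above.

```python
# Hedged usage sketch; SimpleNamespace stands in for ServerArgs here.
import logging
from types import SimpleNamespace

from sglang.srt.utils import configure_logger  # same import path used elsewhere in this diff

args = SimpleNamespace(log_level="info")
configure_logger(args, prefix=" TP0")
logging.getLogger(__name__).info("Decode batch. #running-req: 1")
# Expected output shape: [13:05:22 TP0] Decode batch. #running-req: 1
```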