Unverified Commit 90227800 authored by Lianmin Zheng, committed by GitHub

[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)

parent 30b4f771
@@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high throughput.
 When the server is running at full load, look for the following in the log:

-```[gpu=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```
+```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```

 ### Tune Your Request Submission Speed
 `#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.
......
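The tuning advice above boils down to keeping the waiting queue non-empty, which means submitting requests concurrently rather than one by one. The sketch below is only an illustration of that point and is not taken from the repository: it assumes a locally running server exposing the default `/generate` HTTP endpoint on port 30000 and uses the `requests` library; adjust the payload fields to your model and deployment.

```python
import concurrent.futures

import requests

URL = "http://localhost:30000/generate"  # assumed default local endpoint

def submit(prompt: str) -> dict:
    # One blocking HTTP call; the concurrency comes from the thread pool below.
    payload = {"text": prompt, "sampling_params": {"max_new_tokens": 64}}
    return requests.post(URL, json=payload).json()

prompts = [f"Question {i}: what is {i} + {i}?" for i in range(256)]

# Keep many requests in flight so the scheduler's queue (#queue-req) stays above zero.
with concurrent.futures.ThreadPoolExecutor(max_workers=64) as pool:
    results = list(pool.map(submit, prompts))

print(f"received {len(results)} responses")
```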
@@ -142,17 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        warnings.warn(
-            "For some LLaMA V1 models, initializing the fast tokenizer may "
-            "take a long time. To reduce the initialization time, consider "
-            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer."
-        )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
......
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -193,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
......
@@ -27,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue

-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)

         if not self.is_dp_worker:
@@ -133,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)

     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
......
@@ -56,6 +56,7 @@ class DetokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)

         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
@@ -75,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            recv_obj = await self.recv_from_router.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,
@@ -88,19 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue

             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue

             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -134,6 +137,7 @@ class DetokenizerManager:
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
......
@@ -21,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers
@@ -80,6 +80,7 @@ class TokenizerManager:
     ):
         self.server_args = server_args

+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -87,6 +88,7 @@ class TokenizerManager:
         self.send_to_router = context.socket(zmq.PUSH)
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

+        # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(
@@ -104,6 +106,7 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)

+        # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
@@ -127,6 +130,7 @@ class TokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
             )

+        # Store states
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
@@ -134,63 +138,6 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
-    async def get_pixel_values(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
-                pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hash = [image_hash]
-                image_size = [image_size]
-        elif isinstance(image_data, str):
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hash = [image_hash]
-            image_size = [image_size]
-        else:
-            pixel_values, image_hash, image_size = None, None, None
-
-        return pixel_values, image_hash, image_size
-
-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                get_pixel_values,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return get_pixel_values(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
     async def generate_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
     ):
@@ -198,7 +145,7 @@ class TokenizerManager:
             self.create_handle_loop()

         while self.model_update_lock.locked():
-            await asyncio.sleep(0)
+            await asyncio.sleep(0.001)

         obj.post_init()
         is_single = obj.is_single
@@ -214,8 +161,8 @@ class TokenizerManager:
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request,
-        index=None,
-        is_cache_for_prefill=False,
+        index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
             not_use_index = index is None
@@ -235,7 +182,7 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
                     obj.image_data
                 )
                 return_logprob = (
@@ -345,7 +292,7 @@ class TokenizerManager:
         parallel_sample_num = obj.parallel_sample_num

         if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
+            # Send prefill requests to cache the common prefix
             parallel_sample_num += 1
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
@@ -436,7 +383,6 @@ class TokenizerManager:
             )

         # Then process the responses based on streaming option
-
         is_stream = hasattr(obj, "stream") and obj.stream

         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
@@ -482,9 +428,9 @@ class TokenizerManager:
     async def _get_pixel_values(self, image_data):
         if isinstance(image_data, list) and len(image_data) > 0:
-            return await self.get_pixel_values(image_data[0])
+            return await self._get_pixel_values_internal(image_data[0])
         elif isinstance(image_data, str):
-            return await self.get_pixel_values(image_data)
+            return await self._get_pixel_values_internal(image_data)
         else:
             return None, None, None
@@ -563,6 +509,13 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

+    def abort_request(self, rid: str):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
     async def update_weights(self, obj: UpdateWeightReqInput, request):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -587,13 +540,6 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
-            return
-        del self.rid_to_state[rid]
-        req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
-
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -617,6 +563,8 @@ class TokenizerManager:
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
             recv_obj: Union[
                 BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
@@ -713,11 +661,69 @@ class TokenizerManager:
             )
         return top_logprobs
+    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
+        aspect_ratio = (
+            getattr(self.hf_config, "image_aspect_ratio", None)
+            if aspect_ratio is None
+            else aspect_ratio
+        )
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            pixel_values, image_hash, image_size = [], [], []
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                for img_data in image_data:
+                    pixel_v, image_h, image_s = await self._process_single_image(
+                        img_data, aspect_ratio, grid_pinpoints
+                    )
+                    pixel_values.append(pixel_v)
+                    image_hash.append(image_h)
+                    image_size.append(image_s)
+                pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hash = [image_hash]
+                image_size = [image_size]
+        elif isinstance(image_data, str):
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hash = [image_hash]
+            image_size = [image_size]
+        else:
+            pixel_values, image_hash, image_size = None, None, None
+
+        return pixel_values, image_hash, image_size
+
+    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                _process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return _process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints, self.processor
+            )

 global global_processor

 def init_global_processor(server_args: ServerArgs):
+    """Init the global processor for multi modal models."""
     global global_processor
     transformers.logging.set_verbosity_error()
     global_processor = get_processor(
@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
     )

-def get_pixel_values(
+def _process_single_image_task(
     image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
 ):
     try:
@@ -759,4 +765,4 @@ def get_pixel_values(
         pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
+        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -145,7 +146,6 @@ class ModelTpServer:
         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -284,7 +284,7 @@ class ModelTpServer:
                 self.num_generated_tokens = 0
                 self.last_stats_tic = time.time()
                 logger.info(
-                    f"[gpu={self.gpu_id}] Decode batch. "
+                    f"Decode batch. "
                     f"#running-req: {len(self.running_batch.reqs)}, "
                     f"#token: {num_used}, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -443,7 +443,7 @@ class ModelTpServer:
            if num_mixed_running > 0:
                logger.info(
-                    f"[gpu={self.gpu_id}] Prefill batch"
+                    f"Prefill batch"
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
@@ -453,7 +453,7 @@ class ModelTpServer:
                )
            else:
                logger.info(
-                    f"[gpu={self.gpu_id}] Prefill batch. "
+                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
@@ -631,7 +631,7 @@ class ModelTpServer:
            self.new_token_ratio = new_token_ratio
            logger.info(
-                "decode out of memory happened, "
+                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
@@ -848,7 +848,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
......
@@ -109,7 +109,7 @@ class ModelRunner:
     def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+        logger.info("Init nccl begin.")

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -152,8 +152,7 @@ class ModelRunner:
     def load_model(self):
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
@@ -208,7 +207,7 @@ class ModelRunner:
         )
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight end. "
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -224,7 +223,7 @@ class ModelRunner:
         from vllm.model_executor.model_loader.utils import set_default_torch_dtype

         logger.info(
-            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"Update weights begin. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -298,7 +297,7 @@ class ModelRunner:
         self.load_config = load_config
         self.model_config.path = model_path

-        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        logger.info("Update weights end.")
         return True, "Succeeded to update model weights"

     def profile_max_num_token(self, total_gpu_memory: int):
@@ -387,7 +386,7 @@ class ModelRunner:
             layer_num=self.model_config.num_hidden_layers,
         )
         logger.info(
-            f"[gpu={self.gpu_id}] Memory pool end. "
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -473,9 +472,7 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(
-            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
-        )
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")

         if self.server_args.disable_cuda_graph_padding:
             batch_size_list = list(range(1, 32)) + [64, 128]
......
@@ -123,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name

-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -355,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
         }
     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
......
@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -270,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     server_args.check_server_args()
     _set_envs_and_config(server_args)

-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
......
@@ -418,7 +418,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
-            help="Enabling mixing prefill and decode in a chunked batch.",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
         )
         parser.add_argument(
             "--enable-torch-compile",
......
@@ -692,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
     async def authentication(request, call_next):
         if request.method == "OPTIONS":
@@ -704,7 +704,7 @@ def add_api_key_middleware(app, api_key):
         return await call_next(request)

-def prepare_model(model_path):
+def prepare_model(model_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download
@@ -713,7 +713,7 @@ def prepare_model(model_path):
     return model_path

-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(tokenizer_path):
             from modelscope import snapshot_download
@@ -722,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
             tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
         )
     return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
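To see what the new `configure_logger` helper does to log output, here is a minimal, self-contained sketch. The helper body is copied from the diff above; the `SimpleNamespace` stand-in for `ServerArgs` and the example prefix are assumptions for illustration only (the helper reads nothing but `log_level`).

```python
import logging
from types import SimpleNamespace

def configure_logger(server_args, prefix: str = ""):
    # Same helper as added above: one timestamped, optionally prefixed format.
    format = f"[%(asctime)s{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=format,
        datefmt="%H:%M:%S",
        force=True,  # reconfigure even if logging was already set up
    )

# Stand-in for ServerArgs; only the log_level attribute is used here.
args = SimpleNamespace(log_level="info")

# Each worker process passes its own prefix, e.g. " TP0" or " DP1 TP0".
configure_logger(args, prefix=" TP0")
logging.getLogger("sglang.example").info("Decode batch. #running-req: 3")
# Prints something like: [12:34:56 TP0] Decode batch. #running-req: 3
```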