Unverified commit 20fd53b8, authored by Lianmin Zheng, committed by GitHub

Correctly abort the failed grammar requests & Improve the handling of abort (#6803)

parent 6a47b730
@@ -60,7 +60,7 @@ class BaseGrammarObject:
         raise NotImplementedError()

     def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError()
+        return self

     @property
     def finished(self):
@@ -99,9 +99,12 @@ class BaseGrammarObject:
         raise NotImplementedError()


+INVALID_GRAMMAR_OBJ = BaseGrammarObject()
+
+
 @dataclass
 class CacheEntry:
-    value: Optional[BaseGrammarObject]
+    value: BaseGrammarObject
     event: Event
...
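The hunk above introduces the key idea of this commit: instead of returning None when grammar compilation fails, the backends return a shared INVALID_GRAMMAR_OBJ sentinel that can also be cached, so later requests with the same key are aborted immediately. A minimal, self-contained sketch of that pattern follows (simplified stand-in names, not the SGLang classes):

# Minimal sketch of the sentinel idea (simplified stand-ins, not the SGLang classes).
class GrammarObject:
    def copy(self) -> "GrammarObject":
        # The shared sentinel can safely return itself; real grammars return a fresh copy.
        return self


# One module-level instance marks "compilation failed". Identity comparison keeps it
# distinct from both a valid grammar and None (which still means "cache miss").
INVALID_GRAMMAR = GrammarObject()


def compile_grammar(spec: str) -> GrammarObject:
    try:
        if not spec.strip():
            raise ValueError("empty grammar")
        return GrammarObject()  # stand-in for the real compiler call
    except Exception:
        return INVALID_GRAMMAR


grammar = compile_grammar("")
if grammar is INVALID_GRAMMAR:
    print("abort the request with a 400 instead of silently dropping the constraint")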
@@ -28,6 +28,7 @@ from llguidance.torch import (
 )

 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -126,8 +127,8 @@ class GuidanceBackend(BaseGrammarBackend):
                 serialized_grammar=serialized_grammar,
             )
         except Exception as e:
-            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
-            return None
+            logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
+            return INVALID_GRAMMAR_OBJ

     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
         try:
@@ -138,8 +139,8 @@ class GuidanceBackend(BaseGrammarBackend):
                 },
             )
         except Exception as e:
-            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
-            return None
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_serialized(serialized_grammar)

     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -151,8 +152,8 @@ class GuidanceBackend(BaseGrammarBackend):
             serialized_grammar = grammar_from("ebnf", key_string)
             return self._from_serialized(serialized_grammar)
         except ValueError as e:
-            logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
-            return None
+            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ

     def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
         try:
@@ -169,5 +170,5 @@ class GuidanceBackend(BaseGrammarBackend):
             g = StructTag.to_grammar(tags)
             return self._from_serialized(g)
         except Exception as e:
-            logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
@@ -24,6 +24,7 @@ from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel

 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -151,8 +152,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 # outlines <= 0.0.46
                 guide = RegexGuide(regex, self.outlines_tokenizer)
         except interegular.patterns.InvalidSyntax as e:
-            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
-            return None
+            logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         jump_forward_map = None
         return OutlinesGrammar(guide, jump_forward_map)
@@ -170,8 +171,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 whitespace_pattern=self.whitespace_pattern,
             )
         except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
-            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
-            return None
+            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._compile_regex(regex)

     def dispatch_regex(self, key_string: str):
...
@@ -28,6 +28,7 @@ from xgrammar import (
 )

 from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
     BaseGrammarBackend,
     BaseGrammarObject,
 )
@@ -152,6 +153,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()

-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size
-        )
+        if True:
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size
+            )
@@ -178,25 +180,26 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
-        except RuntimeError as e:
-            logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
-            return None
+        except (RuntimeError, json.decoder.JSONDecodeError) as e:
+            logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)

     def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             ctx = self.grammar_compiler.compile_grammar(key_string)
         except RuntimeError as e:
-            logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)

     def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             ctx = self.grammar_compiler.compile_regex(key_string)
         except RuntimeError as e:
-            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
-            return None
+            logging.error(f"Hit invalid regex: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)

     def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
@@ -213,13 +216,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             ctx = self.grammar_compiler.compile_structural_tag(
                 tags, structural_tag["triggers"]
             )
-        except RuntimeError as e:
-            logging.warning(
-                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
-            )
-            return None
+        except (RuntimeError, json.decoder.JSONDecodeError) as e:
+            logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
+            return INVALID_GRAMMAR_OBJ
         return self._from_context(ctx, key_string)

     def reset(self):
-        if self.grammar_compiler:
-            self.grammar_compiler.clear_cache()
+        self.grammar_compiler.clear_cache()
@@ -256,7 +256,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                     ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                logger.error(f"Error: {e}")
+                logger.error(f"[http_server] Error: {e}")
                 yield b"data: " + orjson.dumps(
                     out, option=orjson.OPT_NON_STR_KEYS
                 ) + b"\n\n"
@@ -274,7 +274,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ).__anext__()
         return ret
     except ValueError as e:
-        logger.error(f"Error: {e}")
+        logger.error(f"[http_server] Error: {e}")
         return _create_error_response(e)
...
@@ -37,6 +37,7 @@ import hashlib
 import logging
 import threading
 from enum import Enum, auto
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -60,7 +62,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
+from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -771,6 +773,16 @@ class Req:
             logger.info(f"{prefix}: {self.time_stats}")
             self.has_log_time_stats = True

+    def set_finish_with_abort(self, error_msg: str):
+        if get_tensor_model_parallel_rank() == 0:
+            logger.error(f"{error_msg}, {self.rid=}")
+        self.multimodal_inputs = None
+        self.grammar = None
+        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.finished_reason = FINISH_ABORT(
+            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+        )
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
...
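The new Req.set_finish_with_abort helper centralizes the abort bookkeeping that was previously repeated in the scheduler. The following is a hedged, standalone sketch of the same idea with stand-in classes (FinishAbort plays the role of FINISH_ABORT; the fields are simplified assumptions, not the real Req):

# Standalone sketch of the abort bookkeeping (stand-in classes, not the real Req).
from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional


@dataclass
class FinishAbort:  # plays the role of FINISH_ABORT in the diff
    message: str
    status_code: HTTPStatus
    err_type: str


@dataclass
class Request:
    rid: str
    origin_input_ids: List[int]
    multimodal_inputs: Optional[dict] = None
    grammar: Optional[object] = None
    finished_reason: Optional[FinishAbort] = None

    def set_finish_with_abort(self, error_msg: str) -> None:
        # Drop heavy per-request state and shrink the prompt to a single token so the
        # scheduler can flush the request through without a long prefill.
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]
        self.finished_reason = FinishAbort(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def finished(self) -> bool:
        return self.finished_reason is not None


req = Request(rid="demo", origin_input_ids=list(range(100)))
req.set_finish_with_abort("Invalid grammar request")
assert req.finished() and req.finished_reason.status_code == HTTPStatus.BAD_REQUEST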
@@ -35,7 +35,10 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.constrained.base_grammar_backend import (
+    INVALID_GRAMMAR_OBJ,
+    create_grammar_backend,
+)
 from sglang.srt.disaggregation.decode import (
     DecodePreallocQueue,
     DecodeTransferQueue,
@@ -949,12 +952,12 @@ class Scheduler(
         if self.disaggregation_mode != DisaggregationMode.NULL:
             # Invalid request for disaggregated mode
             if recv_req.bootstrap_room is None:
-                error_message = (
+                error_msg = (
                     f"Invalid request: Disaggregated request received without "
                     f"boostrap room id. {req.rid=}"
                 )
-                logger.error(error_message)
-                prepare_abort(req, error_message)
+                logger.error(error_msg)
+                prepare_abort(req, error_msg)
                 self.stream_output([req], req.return_logprob)
                 return
@@ -985,29 +988,23 @@ class Scheduler(
             req.extend_image_inputs(image_inputs)

             if len(req.origin_input_ids) >= self.max_req_input_len:
-                error_msg = (
-                    "Multimodal prompt is too long after expanding multimodal tokens. "
-                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
-                )
-                logger.error(error_msg)
-                req.origin_input_ids = [0]
-                req.multimodal_inputs = None
-                req.sampling_params.max_new_tokens = 0
-                req.finished_reason = FINISH_ABORT(
-                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
-                )
+                req.set_finish_with_abort(
+                    error_msg=(
+                        "Multimodal prompt is too long after expanding multimodal tokens. "
+                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                    )
+                )
                 self._add_request_to_queue(req)
                 return

-        # Validate prompts length
+        # Validate prompt length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
         if error_msg:
-            req.origin_input_ids = [0]
-            req.sampling_params.max_new_tokens = 0
+            req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return
@@ -1019,12 +1016,9 @@ class Scheduler(
             req.logprob_start_len = recv_req.logprob_start_len

         if req.logprob_start_len >= len(req.origin_input_ids):
-            req.finished_reason = FINISH_ABORT(
-                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
-                HTTPStatus.BAD_REQUEST,
-                "BadRequestError",
-            )
+            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
+            req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return
@@ -1061,6 +1055,10 @@ class Scheduler(
                 if not cache_hit:
                     req.grammar_key = key
                     add_to_grammar_queue = True
+                else:
+                    if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
+                        error_msg = f"Invalid grammar request with cache hit: {key=}"
+                        req.set_finish_with_abort(error_msg)

         if add_to_grammar_queue:
             req.queue_time_start = time.perf_counter()
@@ -1108,19 +1106,13 @@ class Scheduler(
             req.extend_image_inputs(image_inputs)

             if len(req.origin_input_ids) >= self.max_req_input_len:
-                error_msg = (
-                    "Multimodal prompt is too long after expanding multimodal tokens. "
-                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
-                )
-                logger.error(error_msg)
-                req.origin_input_ids = [0]
-                req.multimodal_inputs = None
-                req.sampling_params.max_new_tokens = 0
-                req.finished_reason = FINISH_ABORT(
-                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
-                )
-                req.queue_time_start = time.perf_counter()
-                self.waiting_queue.append(req)
+                req.set_finish_with_abort(
+                    error_msg=(
+                        "Multimodal prompt is too long after expanding multimodal tokens. "
+                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                    )
+                )
+                self._add_request_to_queue(req)
                 return

         # Validate prompts length
@@ -1785,17 +1777,25 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
-        num_abort_reqs = 0
+        num_timeout_reqs = 0
         for req in self.grammar_queue:
             try:
+                if req.finished():  # It is aborted by AbortReq
+                    num_ready_reqs += 1
+                    continue
                 req.grammar = req.grammar.result(timeout=0.03)
-                if req.grammar:
-                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+                if req.grammar is INVALID_GRAMMAR_OBJ:
+                    req.set_finish_with_abort(
+                        f"Invalid grammar request: {req.grammar_key=}"
+                    )
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
+                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
+                # not the waiting time from it enters the grammar queue.
                 if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
-                    num_abort_reqs = 1
+                    num_timeout_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1807,28 +1807,33 @@ class Scheduler(
         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

             for i in range(num_ready_reqs, num_ready_reqs_max):
                 req = self.grammar_queue[i]
+                if req.finished():  # It is aborted by AbortReq
+                    continue
                 req.grammar = req.grammar.result()
-                if req.grammar:
-                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+                if req.grammar is INVALID_GRAMMAR_OBJ:
+                    req.set_finish_with_abort(
+                        f"Invalid grammar request: {req.grammar_key=}"
+                    )
+        else:
+            num_ready_reqs_max = num_ready_reqs
+            num_timeout_reqs_max = num_timeout_reqs

-        for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
-            req.grammar = None
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
-            logger.error(error_msg)
-            req.finished_reason = FINISH_ABORT(
-                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
-            )
-        num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
+            req.set_finish_with_abort(error_msg)
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
+        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
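The grammar queue above polls each compilation future with a short timeout so the scheduler loop never blocks, counts how many polls a request has waited, and cancels it once the budget is exceeded. Below is a hedged sketch of that polling pattern with plain concurrent.futures (the constants and dict fields are assumptions for the example, not SGLang's values):

# Sketch of the non-blocking polling pattern used above (assumed constants).
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout

GRAMMAR_TIMEOUT = 1.0  # total polling budget in seconds (assumption for this sketch)
POLL = 0.03            # per-future poll timeout, mirrors result(timeout=0.03)


def poll_grammar_futures(pending):
    ready, timed_out = [], []
    for item in pending:
        try:
            item["grammar"] = item["future"].result(timeout=POLL)
            ready.append(item)
        except FutureTimeout:
            # Count polls of this loop only, not the time since the request was queued.
            item["wait_ct"] = item.get("wait_ct", 0) + 1
            if item["wait_ct"] > GRAMMAR_TIMEOUT / POLL:
                item["future"].cancel()
                timed_out.append(item)
            break  # keep FIFO order: stop at the first request that is not ready yet
    return ready, timed_out


with ThreadPoolExecutor() as pool:
    pending = [{"future": pool.submit(time.sleep, d)} for d in (0.0, 0.2)]
    time.sleep(0.05)
    ready, timed_out = poll_grammar_futures(pending)
    print(len(ready), len(timed_out))  # expected: 1 0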
@@ -2024,8 +2029,6 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
-        # TODO(lmzheng): abort the requests in the grammar queue.
-
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
@@ -2047,8 +2050,16 @@ class Scheduler(
             for req in reqs:
                 if req.rid.startswith(recv_req.rid) and not req.finished():
                     logger.debug(f"Abort running request. {req.rid=}")
+                    # We must use to_abort because it is in a running batch
                     req.to_abort = True

+        # Delete the requests in the grammar queue
+        for req in self.grammar_queue:
+            if req.rid.startswith(recv_req.rid):
+                logger.debug(f"Abort grammar queue request. {req.rid=}")
+                req.grammar.cancel()
+                req.set_finish_with_abort("Aborted by AbortReq.")
+
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
...
@@ -221,7 +221,7 @@ class TokenizerManager:
                 self.tokenizer = get_tokenizer_from_processor(self.processor)
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
             else:
-                self.mm_processor = get_dummy_processor()
+                self.mm_processor = None

             if server_args.skip_tokenizer_init:
                 self.tokenizer = self.processor = None
@@ -425,8 +425,8 @@ class TokenizerManager:
         is_single = obj.is_single
         if is_single:
             tokenized_obj = await self._tokenize_one_request(obj)
-            self._send_one_request(obj, tokenized_obj, created_time)
-            async for response in self._wait_one_response(obj, request):
+            state = self._send_one_request(obj, tokenized_obj, created_time)
+            async for response in self._wait_one_response(obj, state, request):
                 yield response
         else:
             async for response in self._handle_batch_request(
@@ -462,8 ... @@ class TokenizerManager:
             )
             input_ids = self.tokenizer.encode(input_text)

-        image_inputs: Optional[Dict] = None
-        if obj.contains_mm_input():
+        if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
                 input_text=input_text or input_ids,
@@ -472,6 ... @@ class TokenizerManager:
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
+        else:
+            image_inputs: Optional[Dict] = None

         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
@@ -631,15 +632,15 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        return state

     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
+        state: ReqState,
         request: Optional[fastapi.Request] = None,
     ):
         """Wait for the response of one request."""
-        state = self.rid_to_state[obj.rid]

         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -709,16 +710,16 @@ class TokenizerManager:
                 for i, tokenized_obj in enumerate(tokenized_objs):
                     tmp_obj = obj[i]
-                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, request))
+                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, state, request))
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
                 for i in range(batch_size):
                     tmp_obj = obj[i]
                     tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, request))
+                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, state, request))
                     rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
@@ -743,8 +744,8 @@ class TokenizerManager:
                 tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                 tokenized_obj.sampling_params.max_new_tokens = 0
                 tokenized_obj.stream = False
-                self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                await self._wait_one_response(tmp_obj, request).__anext__()
+                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                await self._wait_one_response(tmp_obj, state, request).__anext__()

             # Expand requests, assign new rids for them, and send them
             for i in range(batch_size):
@@ -752,8 +753,8 @@ class TokenizerManager:
                 tmp_obj = copy.copy(objs[i])
                 tokenized_obj = copy.copy(tokenized_objs[i])
                 tokenized_obj.rid = tmp_obj.regenerate_rid()
-                self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, request))
+                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                generators.append(self._wait_one_response(tmp_obj, state, request))
                 rids.append(tmp_obj.rid)

         # Wait for all requests
@@ -789,6 +790,9 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

+        if self.enable_metrics:
+            self.metrics_collector.observe_one_aborted_request()
+
     async def start_profile(
         self,
         output_dir: Optional[str] = None,
...
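In the refactor above, _send_one_request now returns the ReqState it creates and _wait_one_response takes that state directly instead of looking it up again by rid. A minimal, hedged sketch of that shape with stand-in classes (not the real TokenizerManager):

# Minimal sketch of the send-returns-state pattern (stand-in classes, not the real manager).
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ReqState:
    out_list: List[dict] = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


class Manager:
    def __init__(self) -> None:
        self.rid_to_state: Dict[str, ReqState] = {}

    def send_one_request(self, rid: str) -> ReqState:
        state = ReqState()
        self.rid_to_state[rid] = state
        return state  # hand the state back instead of making the waiter look it up

    async def wait_one_response(self, state: ReqState) -> dict:
        await state.event.wait()
        return state.out_list[-1]


async def main() -> None:
    m = Manager()
    state = m.send_one_request("rid-0")

    def deliver() -> None:
        state.out_list.append({"text": "ok"})
        state.event.set()

    asyncio.get_running_loop().call_later(0.01, deliver)
    print(await m.wait_one_response(state))  # {'text': 'ok'}


asyncio.run(main())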
@@ -35,10 +35,6 @@ def validate_input_length(
             f"the maximum allowed length ({max_req_input_len} tokens). "
             f"Use a shorter input or enable --allow-auto-truncate."
         )
-        logger.error(error_msg)
-        req.finished_reason = FINISH_ABORT(
-            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
-        )
         return error_msg

     return None
@@ -402,6 +402,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

+        self.num_aborted_requests_total = Counter(
+            name="sglang:num_aborted_requests",
+            documentation="Number of requests aborted.",
+            labelnames=labels.keys(),
+        )
+
         if bucket_time_to_first_token is None:
             bucket_time_to_first_token = [
                 0.1,
@@ -533,3 +539,6 @@ class TokenizerMetricsCollector:
             if adjusted_interval <= bound:
                 his._buckets[i].inc(num_new_tokens)
                 break
+
+    def observe_one_aborted_request(self):
+        self.num_aborted_requests_total.labels(**self.labels).inc(1)
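For reference, here is a standalone sketch of the new aborted-request counter, assuming the Counter used above is prometheus_client's (the label set below is an assumption chosen for the example):

# Standalone sketch of the aborted-request counter (assumed labels, not SGLang's).
from prometheus_client import Counter, generate_latest

labels = {"model_name": "demo-model"}

num_aborted_requests_total = Counter(
    name="sglang:num_aborted_requests",
    documentation="Number of requests aborted.",
    labelnames=labels.keys(),
)


def observe_one_aborted_request() -> None:
    num_aborted_requests_total.labels(**labels).inc(1)


observe_one_aborted_request()
# The exposition output contains a line like:
#   sglang:num_aborted_requests_total{model_name="demo-model"} 1.0
print(generate_latest().decode())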
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import tqdm

-from sglang.srt import two_batch_overlap
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
@@ -133,28 +132,27 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
+                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
             else:
                 capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
             capture_bs = (
-                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+                list(range(1, 9))
+                + list(range(10, 33, 2))
+                + list(range(40, 64, 8))
+                + list(range(80, 161, 16))
             )

     gpu_mem = get_device_memory_capacity()
     if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
+    if gpu_mem is not None and gpu_mem > 180 * 1000:
+        capture_bs += list(range(256, 528, 16))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
-        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
-            model_runner.req_to_token_pool.size
-        ]
+        capture_bs += [model_runner.req_to_token_pool.size]

     if server_args.enable_two_batch_overlap:
         capture_bs = [bs for bs in capture_bs if bs >= 2]
@@ -167,7 +165,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
...
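As a quick sanity check on the new ranges, the snippet below evaluates the speculative-decoding capture list from the hunk above (pure arithmetic, no SGLang imports):

# Worked example: the new capture list for the speculative-decoding branch.
capture_bs = (
    list(range(1, 9))           # 1..8
    + list(range(10, 33, 2))    # 10, 12, ..., 32
    + list(range(40, 64, 8))    # 40, 48, 56
    + list(range(80, 161, 16))  # 80, 96, ..., 160
)
print(len(capture_bs), capture_bs[:5], capture_bs[-3:])
# 29 [1, 2, 3, 4, 5] [128, 144, 160]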
@@ -918,7 +918,7 @@ class ModelRunner:
         if self.req_to_token_pool is None:
             self.req_to_token_pool = ReqToTokenPool(
-                size=max_num_reqs + 1,
+                size=max_num_reqs,
                 max_context_len=self.model_config.context_len + 4,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
...
@@ -2055,6 +2055,12 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+def is_blackwell():
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == 10
+
+
 def get_free_port():
     # try ipv4
     try:
...
@@ -127,6 +127,10 @@ def send_one_prompt(args):
     if args.batch_size > 1:
         ret = ret[0]

+    if response.status_code != 200:
+        print(ret)
+        return 0, 0
+
     latency = ret["meta_info"]["e2e_latency"]

     if "spec_verify_ct" in ret["meta_info"]:
...
@@ -881,20 +881,24 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
     return rouge_l_scores


-STDERR_FILENAME = "stderr.txt"
-STDOUT_FILENAME = "stdout.txt"
+STDERR_FILENAME = "/tmp/stderr.txt"
+STDOUT_FILENAME = "/tmp/stdout.txt"


 def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
     """Print the output in real time with another thread."""
     while not os.path.exists(filename):
-        time.sleep(1)
+        time.sleep(0.01)

     pt = 0
     while pt >= 0:
         if pt > 0 and not os.path.exists(filename):
             break
-        lines = open(filename).readlines()
+        try:
+            lines = open(filename).readlines()
+        except FileNotFoundError:
+            print(f"{pt=}, {os.path.exists(filename)=}")
+            raise
         for line in lines[pt:]:
             print(line, end="", flush=True)
             output_lines.append(line)
...
 #!/bin/bash
-# Show current GPU status
-nvidia-smi
-
-# Clean SGLang processes
-pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
-
-# Clean all GPU processes if any argument is provided
-if [ $# -gt 0 ]; then
+if [ "$1" = "rocm" ]; then
+    echo "Running in ROCm mode"
+    # Clean SGLang processes
+    pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
+else
+    # Show current GPU status
+    nvidia-smi
+
+    # Clean SGLang processes
+    pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt' | xargs -r kill -9
+
+    # Clean all GPU processes if any argument is provided
+    if [ $# -gt 0 ]; then
     # Check if sudo is available
     if command -v sudo >/dev/null 2>&1; then
         sudo apt-get update
@@ -18,8 +25,9 @@ if [ $# -gt 0 ]; then
     fi
     kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
     lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
     fi

     # Show GPU status after clean up
     nvidia-smi
+fi