Unverified Commit de167cf5 authored by Lianmin Zheng, committed by GitHub

Fix request abortion (#6184)

parent 4319978c
@@ -56,7 +56,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        part: [0, 1, 2, 3, 4, 5, 6, 7]
+        part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -69,7 +69,7 @@ jobs:
         timeout-minutes: 30
         run: |
           cd test/srt
-          python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 8
+          python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9

   unit-test-backend-2-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     return Response(status_code=200)


+@app.post("/abort_request")
+async def abort_request(obj: AbortReq, request: Request):
+    """Abort a request."""
+    try:
+        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.post("/parse_function_call")
 async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
     """
...
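For reference, a minimal client-side sketch of calling the new endpoint. The host, port, and rid below are placeholders; it assumes the caller knows the rid of an in-flight request (for example because it supplied one in the original /generate body).

```python
import requests

# Illustrative: abort an in-flight request by its rid via the new endpoint.
# Assumes an SGLang server at localhost:30000 (the default port); the rid is a placeholder.
base_url = "http://localhost:30000"
rid = "my-request-0"

resp = requests.post(f"{base_url}/abort_request", json={"rid": rid}, timeout=5)
print(resp.status_code)  # 200 if the abort was accepted, an error response otherwise
```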
 from __future__ import annotations

-import hashlib
-from enum import Enum, auto
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason):
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error", status_code=None, err_type=None):
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type
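A tiny sketch of the new default behavior. The import path below is the usual location of this class and is an assumption on my part.

```python
# Assumed import path; FINISH_ABORT is defined in the file patched above.
from sglang.srt.managers.schedule_batch import FINISH_ABORT

# With the new default, an abort without an explicit message reads "Aborted",
# while an explicit message (e.g. from Req.to_abort_message) is preserved.
print(FINISH_ABORT().message)                          # -> "Aborted"
print(FINISH_ABORT("disconnected by client").message)  # -> "disconnected by client"
```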
@@ -441,11 +442,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = "Unknown error"
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
@@ -546,8 +549,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,15 +556,15 @@ class Req:
         #     start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0

-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1

+        self.metadata_buffer_index: int = -1
+        # The first output_id transferred from prefill instance.
+        self.transferred_output_id: Optional[int] = None
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -697,13 +698,29 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0

+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )


+# Batch id
 bid = 0
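A short sketch of how the two new helpers pair up. The caller, the `req` object, and the moment at which the swap happens are hypothetical; the two pool arguments are the scheduler-owned `req_to_token_pool` and `token_to_kv_pool_allocator`.

```python
# Hypothetical helper showing the round trip between the two new Req methods.
# `req` is a Req whose tokens are currently resident in the device KV pool.
def swap_out_and_back(req, req_to_token_pool, token_to_kv_pool_allocator):
    # Copy the KV entries for tokens [0, seqlen - 1) of this request to host memory.
    req.offload_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)

    # ... the freed device KV slots could be reused by other requests here ...

    # Write the saved host copy back into the request's token slots and drop it.
    req.load_kv_cache(req_to_token_pool, token_to_kv_pool_allocator)
```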
@@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and not self.reqs[i] in chunked_req_to_exclude
+                and self.reqs[i] not in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
...
@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -907,19 +903,6 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +918,7 @@ class Scheduler(
                 stream=recv_req.stream,
                 lora_path=recv_req.lora_path,
                 input_embeds=recv_req.input_embeds,
-                custom_logit_processor=custom_logit_processor,
+                custom_logit_processor=recv_req.custom_logit_processor,
                 return_hidden_states=recv_req.return_hidden_states,
                 eos_token_ids=self.model_config.hf_eos_token_id,
                 bootstrap_host=recv_req.bootstrap_host,
@@ -1246,9 +1229,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1256,9 +1237,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1774,24 +1753,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
-        # Wait for some time so that the parent process can print the error.
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
@@ -1923,25 +1905,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
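To summarize the new matching behavior, here is a simplified, standalone sketch of the prefix-based abort. The prefix match (`startswith`) suggests that sub-requests of a batch share the parent rid as a prefix, so one AbortReq can now cancel all of them instead of returning after the first hit; the real method additionally notifies the tokenizer and considers `cur_batch`.

```python
# Simplified sketch of the abort flow above (not the actual Scheduler method).
def abort_matching(waiting_queue, running_reqs, abort_rid):
    # Drop every matching request that is still waiting.
    kept = []
    for req in waiting_queue:
        if req.rid.startswith(abort_rid):
            print(f"abort queued {req.rid}")
        else:
            kept.append(req)
    waiting_queue[:] = kept

    # Mark matching running requests; the event loop later finishes them
    # with FINISH_ABORT instead of silently filtering them out.
    for req in running_reqs:
        if req.rid.startswith(abort_rid) and not req.finished():
            req.to_abort = True
```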
...
@@ -15,6 +15,8 @@ if TYPE_CHECKING:
         Scheduler,
     )

+DEFAULT_FORCE_STREAM_INTERVAL = 50
+

 class SchedulerOutputProcessorMixin:
     """
@@ -512,19 +514,26 @@ class SchedulerOutputProcessorMixin:
                 if self.model_config.is_multimodal_gen and req.to_abort:
                     continue

-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
+                if req.finished():
+                    if req.finished_output:
+                        # With the overlap schedule, a request will try to output twice and hit this line twice
+                        # because of the one additional delayed token. This "continue" prevented the dummy output.
+                        continue
+                    req.finished_output = True
+                    should_output = True
+                else:
+                    if req.stream:
+                        stream_interval = (
+                            req.sampling_params.stream_interval or self.stream_interval
+                        )
+                        should_output = len(req.output_ids) % stream_interval == 0
+                    else:
+                        should_output = (
+                            len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                            and not self.model_config.is_multimodal_gen
+                        )
+
+                if should_output:
                     rids.append(req.rid)
                     finished_reasons.append(
                         req.finished_reason.to_json() if req.finished_reason else None
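The decision logic above, condensed into a standalone sketch. Names mirror the diff; `DEFAULT_FORCE_STREAM_INTERVAL` is the new module constant introduced at the top of this file.

```python
# Standalone sketch of the new should_output decision (mirrors the diff above).
DEFAULT_FORCE_STREAM_INTERVAL = 50

def should_output(finished, finished_output, stream, n_output_ids,
                  per_request_stream_interval, server_stream_interval,
                  is_multimodal_gen=False):
    if finished:
        # Emit the final chunk exactly once, even if the overlap schedule
        # revisits the request for its one extra delayed token.
        return not finished_output
    if stream:
        # A per-request sampling_params.stream_interval overrides the server default.
        interval = per_request_stream_interval or server_stream_interval
        return n_output_ids % interval == 0
    # Non-streaming requests are still flushed periodically for incremental decoding.
    return n_output_ids % DEFAULT_FORCE_STREAM_INTERVAL == 0 and not is_multimodal_gen
```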
...
@@ -288,6 +288,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -341,13 +342,14 @@ class TokenizerManager:
             ]
         )

+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +484,14 @@ class TokenizerManager:
             session_params = (
                 SessionParams(**obj.session_params) if obj.session_params else None
             )

+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
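With this change the rejection happens up front in TokenizerManager and surfaces to the caller, instead of being silently dropped in the scheduler as before. A hedged client-side sketch (the payload shape and port are illustrative):

```python
import requests

# Hypothetical request against a server started WITHOUT the custom logit processor flag.
# The point is that the request now fails fast with an error response instead of the
# processor being silently ignored.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hello",
        "sampling_params": {"max_new_tokens": 8},
        "custom_logit_processor": "<serialized processor>",  # placeholder value
    },
)
print(resp.status_code, resp.text)  # expect an error, not a silent fallback
```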
@@ -570,9 +580,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
@@ -587,10 +597,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue
@@ -605,7 +616,6 @@ class TokenizerManager:
                 else:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
-                del self.rid_to_state[obj.rid]

             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +635,11 @@ class TokenizerManager:
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
@@ -728,7 +739,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
@@ -964,7 +974,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1035,6 +1045,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1098,6 +1111,7 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1246,6 +1260,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]):
             self._result_values.append(recv_obj)
             if len(self._result_values) == self._fan_out:
                 self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#
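An illustrative client for the "http / streaming" rows of the table above: start a streaming generation, read one chunk, then drop the connection. FastAPI cancels the response task, and the background task created by `create_abort_task()` sends the abort to the scheduler about two seconds later (the `asyncio.sleep(2)` in this diff). URL, prompt, and field values are placeholders.

```python
import requests

# Illustrative streaming client that disconnects early to trigger the abort path.
with requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a long story.",
        "stream": True,
        "sampling_params": {"max_new_tokens": 512},
    },
    stream=True,
) as resp:
    first_chunk = next(resp.iter_content(chunk_size=None))
    # Leaving the `with` block closes the connection before generation finishes.
```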
@@ -50,6 +50,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -75,6 +76,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
+        self.stream_interval = stream_interval

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
...
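The new field can be set per request through `sampling_params`, overriding the server-wide stream interval for that request only. A minimal illustrative call (URL and values are placeholders):

```python
import requests

# Illustrative: ask for a smaller per-request streaming interval via the new field.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Count from 1 to 50.",
        "stream": True,
        "sampling_params": {"max_new_tokens": 128, "stream_interval": 2},
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line)
```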
@@ -27,6 +27,7 @@ class BenchArgs:
         "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
     )
     image: bool = False
+    many_images: bool = False
     stream: bool = False

     @staticmethod
@@ -48,6 +49,7 @@ class BenchArgs:
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
         parser.add_argument("--image", action="store_true")
+        parser.add_argument("--many-images", action="store_true")
         parser.add_argument("--stream", action="store_true")

     @classmethod
@@ -62,6 +64,17 @@ def send_one_prompt(args):
             "Human: Describe this image in a very short sentence.\n\nAssistant:"
         )
         image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    elif args.many_images:
+        args.prompt = (
+            "Human: I have one reference image and many images."
+            "Describe their relationship in a very short sentence.\n\nAssistant:"
+        )
+        image_data = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
     else:
         image_data = None
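The new `--many-images` flag exercises the multi-image path: `image_data` becomes a list of URLs instead of a single string. A hedged sketch of the resulting request payload (field names mirror the single-image case; the endpoint call is illustrative):

```python
# Illustrative payload produced by the --many-images path.
payload = {
    "text": (
        "Human: I have one reference image and many images."
        "Describe their relationship in a very short sentence.\n\nAssistant:"
    ),
    "image_data": [
        "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
    ] * 4,
    "sampling_params": {"max_new_tokens": 32},
}
# e.g. requests.post("http://localhost:30000/generate", json=payload)
```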
@@ -74,9 +87,6 @@ def send_one_prompt(args):
             "Write in a format of json.\nAssistant:"
         )
         json_schema = "$$ANY$$"
-        json_schema = (
-            '{"type": "object", "properties": {"population": {"type": "integer"}}}'
-        )
     else:
         json_schema = None
...
@@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase):
             f"### test_vlm_online_latency\n"
             f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
         )
-        self.assertLess(res["median_e2e_latency_ms"], 16000)
+        self.assertLess(res["median_e2e_latency_ms"], 16500)

         if os.getenv("SGLANG_AMD_CI") == "1":
             self.assertLess(res["median_ttft_ms"], 150)
             # TODO: not set yet, need AMD machine
...
@@ -3,7 +3,6 @@ Usage:
 python3 test/srt/test_flashmla.py
 """

-import os
 import unittest
 from types import SimpleNamespace

@@ -61,7 +60,7 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
         metrics = run_eval_few_shot_gsm8k(args)
         print(metrics)

-        self.assertGreater(metrics["accuracy"], 0.62)
+        self.assertGreater(metrics["accuracy"], 0.60)


 class TestFlashMLAAttnLatency(unittest.TestCase):
...