Unverified commit de167cf5, authored by Lianmin Zheng and committed by GitHub

Fix request abortion (#6184)

parent 4319978c
......@@ -56,7 +56,7 @@ jobs:
strategy:
fail-fast: false
matrix:
part: [0, 1, 2, 3, 4, 5, 6, 7]
part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
steps:
- name: Checkout code
uses: actions/checkout@v4
......@@ -69,7 +69,7 @@ jobs:
timeout-minutes: 30
run: |
cd test/srt
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 8
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9
unit-test-backend-2-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
......
......@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
......@@ -539,6 +540,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
return Response(status_code=200)
@app.post("/abort_request")
async def abort_request(obj: AbortReq, request: Request):
"""Abort a request."""
try:
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
return Response(status_code=200)
except Exception as e:
return _create_error_response(e)
@app.post("/parse_function_call")
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
"""
......
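A minimal client-side sketch of the `/abort_request` endpoint added above. The base URL, port, and the example rid are assumptions for illustration; only the endpoint path and the `rid` payload field come from this diff.

```python
import requests

def abort_generation(rid: str, base_url: str = "http://localhost:30000") -> int:
    # Ask the server to abort an in-flight request by its rid. The handler above
    # forwards the rid to TokenizerManager.abort_request, which sends an AbortReq
    # to the scheduler.
    resp = requests.post(f"{base_url}/abort_request", json={"rid": rid})
    return resp.status_code  # 200 when the abort was accepted

# Example (hypothetical rid):
# abort_generation("chatcmpl-1234")
```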
from __future__ import annotations
import hashlib
from enum import Enum, auto
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
"""
import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
import numpy as np
......@@ -134,9 +135,9 @@ class FINISH_LENGTH(BaseFinishReason):
class FINISH_ABORT(BaseFinishReason):
def __init__(self, message="Unknown error", status_code=None, err_type=None):
def __init__(self, message=None, status_code=None, err_type=None):
super().__init__(is_error=True)
self.message = message
self.message = message or "Aborted"
self.status_code = status_code
self.err_type = err_type
......@@ -441,11 +442,13 @@ class Req:
# Check finish
self.tokenizer = None
self.finished_reason = None
# Whether this request has finished output
self.finished_output = None
# If we want to abort the request in the middle of the event loop, set this to true
# Note: We should never set finished_reason in the middle of the event loop; otherwise the req would be filtered out and never respond
self.to_abort = False
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self.to_abort_message: str = "Unknown error"
self.to_abort_message: str = None
self.stream = stream
self.eos_token_ids = eos_token_ids
......@@ -546,8 +549,6 @@ class Req:
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[BaseKVSender] = None
# used for warmup because we don't have a pair yet when init
self.skip_kv_transfer: bool = False
# the start index of the sent kv cache
# We want to send it chunk by chunk for chunked prefill.
# After every chunk forward, we do the following:
......@@ -555,15 +556,15 @@ class Req:
# start_send_idx = len(req.fill_ids)
self.start_send_idx: int = 0
self.metadata_buffer_index: int = -1
# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None
# For the overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill`, rather than `process_prefill_chunk` as in the non-overlap schedule.
# This is because the kv cache is not ready yet in `process_prefill_chunk`.
# We use `tmp_end_idx` to store the end index of the kv cache to send.
self.tmp_end_idx: int = -1
self.metadata_buffer_index: int = -1
# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
......@@ -697,13 +698,29 @@ class Req:
self.req_pool_idx = None
self.already_computed = 0
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
token_indices = req_to_token_pool.req_to_token[
self.req_pool_idx, : self.seqlen - 1
]
self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
token_indices = req_to_token_pool.req_to_token[
self.req_pool_idx, : self.seqlen - 1
]
token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
del self.kv_cache_cpu
def __repr__(self):
return (
f"Req(rid={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
f"{self.grammar=}, "
f"{self.sampling_params=})"
)
# Batch id
bid = 0
......@@ -1447,7 +1464,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and not self.reqs[i] in chunked_req_to_exclude
and self.reqs[i] not in chunked_req_to_exclude
]
if keep_indices is None or len(keep_indices) == 0:
......
......@@ -20,7 +20,6 @@ import signal
import sys
import threading
import time
import warnings
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
......@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
......@@ -135,6 +130,7 @@ from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
crash_on_warnings,
disable_request_logging,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
......@@ -907,19 +903,6 @@ class Scheduler(
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids
# Handle custom logit processor passed to the request
custom_logit_processor = recv_req.custom_logit_processor
if (
not self.server_args.enable_custom_logit_processor
and custom_logit_processor is not None
):
logger.warning(
"The SGLang server is not configured to enable custom logit processor."
"The custom logit processor passed in will be ignored."
"Please set --enable-custom-logits-processor to enable this feature."
)
custom_logit_processor = None
if recv_req.bootstrap_port is None:
# Use default bootstrap port
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
......@@ -935,7 +918,7 @@ class Scheduler(
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
custom_logit_processor=custom_logit_processor,
custom_logit_processor=recv_req.custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
......@@ -1246,9 +1229,7 @@ class Scheduler(
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
raise ValueError(msg)
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
msg = (
......@@ -1256,9 +1237,7 @@ class Scheduler(
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
)
warnings.warn(msg)
if crash_on_warnings():
raise ValueError(msg)
raise ValueError(msg)
if (
self.enable_metrics
......@@ -1774,24 +1753,27 @@ class Scheduler(
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
# Wait for some time so that the parent process can print the error.
if not disable_request_logging():
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
pyspy_dump_schedulers()
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
# Wait for some time so that the parent process can print the error.
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
......@@ -1923,25 +1905,30 @@ class Scheduler(
)
def abort_request(self, recv_req: AbortReq):
# TODO(lmzheng): abort the requests in the grammar queue.
# Delete requests in the waiting queue
to_del = []
for i, req in enumerate(self.waiting_queue):
if req.rid.startswith(recv_req.rid):
to_del.append(i)
break
# Iterate in reverse order to avoid index shifting while popping
for i in sorted(to_del, reverse=True):
for i in reversed(to_del):
req = self.waiting_queue.pop(i)
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
logger.debug(f"Abort queued request. {req.rid=}")
return
# Delete requests in the running batch
for req in self.running_batch.reqs:
if self.cur_batch is self.running_batch or self.cur_batch is None:
reqs = self.running_batch.reqs
else:
reqs = self.running_batch.reqs + self.cur_batch.reqs
for req in reqs:
if req.rid.startswith(recv_req.rid) and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True
return
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
......
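A dependency-free sketch of the abort flow in `Scheduler.abort_request` above: queued requests are removed immediately, while running requests (now including `cur_batch`) are only flagged with `to_abort` and finished by the event loop later. The class and rid naming scheme below are illustrative assumptions, not the real `Req`/`Scheduler`.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakeReq:
    rid: str
    to_abort: bool = False

def abort(waiting: List[FakeReq], running: List[FakeReq], rid: str) -> None:
    # Matching is by rid prefix, so aborting "parent-1" also covers "parent-1_0", ...
    to_del = [i for i, r in enumerate(waiting) if r.rid.startswith(rid)]
    # Pop in reverse order so earlier indices stay valid.
    for i in reversed(to_del):
        waiting.pop(i)
    # Running requests are never dropped in place; they are flagged and the
    # scheduler finishes them with FINISH_ABORT on a later iteration.
    for r in running:
        if r.rid.startswith(rid):
            r.to_abort = True

waiting = [FakeReq("parent-1_0"), FakeReq("other-2")]
running = [FakeReq("parent-1_1")]
abort(waiting, running, "parent-1")
assert [r.rid for r in waiting] == ["other-2"] and running[0].to_abort
```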
......@@ -15,6 +15,8 @@ if TYPE_CHECKING:
Scheduler,
)
DEFAULT_FORCE_STREAM_INTERVAL = 50
class SchedulerOutputProcessorMixin:
"""
......@@ -512,19 +514,26 @@ class SchedulerOutputProcessorMixin:
if self.model_config.is_multimodal_gen and req.to_abort:
continue
if (
req.finished()
# If stream, follow the given stream_interval
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
# TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
# always increase one-by-one.
or (
not req.stream
and len(req.output_ids) % 50 == 0
and not self.model_config.is_multimodal_gen
)
):
if req.finished():
if req.finished_output:
# With the overlap schedule, a request will try to output twice and hit this line twice
# because of the one additional delayed token. This `continue` prevents the dummy output.
continue
req.finished_output = True
should_output = True
else:
if req.stream:
stream_interval = (
req.sampling_params.stream_interval or self.stream_interval
)
should_output = len(req.output_ids) % stream_interval == 0
else:
should_output = (
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
and not self.model_config.is_multimodal_gen
)
if should_output:
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
......
......@@ -288,6 +288,7 @@ class TokenizerManager:
),
self._handle_batch_output,
),
(AbortReq, self._handle_abort_req),
(OpenSessionReqOutput, self._handle_open_session_req_output),
(
UpdateWeightFromDiskReqOutput,
......@@ -341,13 +342,14 @@ class TokenizerManager:
]
)
# For PD disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# for disaggregtion, start kv boostrap server on prefill
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tm
kv_bootstrap_server_class = get_kv_class(
......@@ -482,6 +484,14 @@ class TokenizerManager:
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
......@@ -570,9 +580,9 @@ class TokenizerManager:
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
self.send_to_scheduler.send_pyobj(tokenized_obj)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
self.send_to_scheduler.send_pyobj(tokenized_obj)
async def _wait_one_response(
self,
......@@ -587,10 +597,11 @@ class TokenizerManager:
await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
# Abort the request when the client has disconnected (non-streaming, waiting queue)
self.abort_request(obj.rid)
# Use exception to kill the whole call stack and asyncio task
raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
)
continue
......@@ -605,7 +616,6 @@ class TokenizerManager:
else:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
......@@ -625,10 +635,11 @@ class TokenizerManager:
yield out
else:
if request is not None and await request.is_disconnected():
# Abort the request when the client has disconnected (non-streaming, running)
self.abort_request(obj.rid)
# Use exception to kill the whole call stack and asyncio task
raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
)
async def _handle_batch_request(
......@@ -728,7 +739,6 @@ class TokenizerManager:
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req)
......@@ -964,7 +974,7 @@ class TokenizerManager:
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
await asyncio.sleep(1)
await asyncio.sleep(2)
if obj.is_single:
self.abort_request(obj.rid)
else:
......@@ -1035,6 +1045,9 @@ class TokenizerManager:
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
logger.error(
f"Received output for {rid=} but the state was deleted in TokenizerManager."
)
continue
# Build meta_info and return value
......@@ -1098,6 +1111,7 @@ class TokenizerManager:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
state.finished_time = time.time()
meta_info["e2e_latency"] = state.finished_time - state.created_time
del self.rid_to_state[rid]
state.out_list.append(out_dict)
state.event.set()
......@@ -1246,6 +1260,9 @@ class TokenizerManager:
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid)
def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
......@@ -1325,3 +1342,15 @@ class _Communicator(Generic[T]):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_event.set()
# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
#
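A hedged client-side sketch of the first table row above (http entrypoint, streaming): closing the response stream disconnects the client, and the background task from `create_abort_task` then calls `abort_request` after its sleep. The URL, prompt, and chunk count are assumptions; the `/generate` payload shape mirrors the server's generate API.

```python
import asyncio
import httpx

async def cancel_streaming_generation(base_url: str = "http://localhost:30000") -> None:
    payload = {
        "text": "Write a very long story.",
        "sampling_params": {"max_new_tokens": 2048},
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/generate", json=payload) as resp:
            received = 0
            async for _chunk in resp.aiter_text():
                received += 1
                if received >= 3:
                    # Breaking out closes the connection. On the server side, the
                    # background abort task notices the disconnect (~2 s later) and
                    # sends an AbortReq; rid_to_state is then cleaned up in
                    # _handle_abort_req or _handle_batch_output (see the table above).
                    break

# asyncio.run(cancel_streaming_generation())
```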
......@@ -50,6 +50,7 @@ class SamplingParams:
spaces_between_special_tokens: bool = True,
no_stop_trim: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
......@@ -75,6 +76,7 @@ class SamplingParams:
self.spaces_between_special_tokens = spaces_between_special_tokens
self.no_stop_trim = no_stop_trim
self.custom_params = custom_params
self.stream_interval = stream_interval
# Process some special cases
if 0 <= self.temperature < _SAMPLING_EPS:
......
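A small usage sketch for the per-request `stream_interval` added to `SamplingParams` above; it overrides the server-wide stream interval for just that request. The endpoint URL and prompt are assumptions for illustration.

```python
import requests

payload = {
    "text": "Explain the KV cache in two sentences.",
    "stream": True,
    "sampling_params": {
        "max_new_tokens": 128,
        # Emit a streamed chunk every 4 decoded tokens for this request only;
        # when omitted, the server-wide stream_interval is used.
        "stream_interval": 4,
    },
}
with requests.post("http://localhost:30000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```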
......@@ -27,6 +27,7 @@ class BenchArgs:
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
)
image: bool = False
many_images: bool = False
stream: bool = False
@staticmethod
......@@ -48,6 +49,7 @@ class BenchArgs:
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
parser.add_argument("--image", action="store_true")
parser.add_argument("--many-images", action="store_true")
parser.add_argument("--stream", action="store_true")
@classmethod
......@@ -62,6 +64,17 @@ def send_one_prompt(args):
"Human: Describe this image in a very short sentence.\n\nAssistant:"
)
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
elif args.many_images:
args.prompt = (
"Human: I have one reference image and many images."
"Describe their relationship in a very short sentence.\n\nAssistant:"
)
image_data = [
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
]
else:
image_data = None
......@@ -74,9 +87,6 @@ def send_one_prompt(args):
"Write in a format of json.\nAssistant:"
)
json_schema = "$$ANY$$"
json_schema = (
'{"type": "object", "properties": {"population": {"type": "integer"}}}'
)
else:
json_schema = None
......
......@@ -190,7 +190,7 @@ class TestBenchServing(CustomTestCase):
f"### test_vlm_online_latency\n"
f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
)
self.assertLess(res["median_e2e_latency_ms"], 16000)
self.assertLess(res["median_e2e_latency_ms"], 16500)
if os.getenv("SGLANG_AMD_CI") == "1":
self.assertLess(res["median_ttft_ms"], 150)
# TODO: not set yet, need AMD machine
......
......@@ -3,7 +3,6 @@ Usage:
python3 test/srt/test_flashmla.py
"""
import os
import unittest
from types import SimpleNamespace
......@@ -61,7 +60,7 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
self.assertGreater(metrics["accuracy"], 0.60)
class TestFlashMLAAttnLatency(unittest.TestCase):
......