Unverified commit 8dbdc018 authored by Lianmin Zheng, committed by GitHub

Abort disconnected requests (#457)

parent 3e684be7
@@ -580,8 +580,8 @@ class StreamExecutor:
def _execute_role_end(self, expr: SglRoleEnd):
if (
self.cur_role == "assistant"
and self.backend.is_chat_model
and self.api_num_spec_tokens is not None
and self.backend.is_chat_model
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
......
@@ -19,6 +19,7 @@ class FinishReason(IntEnum):
EOS_TOKEN = auto()
LENGTH = auto()
STOP_STR = auto()
ABORT = auto()
@staticmethod
def to_str(reason):
@@ -28,6 +29,8 @@ class FinishReason(IntEnum):
return "length"
elif reason == FinishReason.STOP_STR:
return "stop"
elif reason == FinishReason.ABORT:
return "abort"
else:
return None
@@ -86,6 +89,35 @@ class Req:
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
def check_finished(self):
if self.finished:
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished = True
self.finish_reason = FinishReason.LENGTH
return
if (
self.output_ids[-1] == self.tokenizer.eos_token_id
and self.sampling_params.ignore_eos == False
):
self.finished = True
self.finish_reason = FinishReason.EOS_TOKEN
return
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str:
self.finished = True
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return
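Note that check_finished only ever assigns LENGTH, EOS_TOKEN, or STOP_STR. The new ABORT reason is set directly by the scheduler (see abort_request in ModelRpcServer below), and the early return at the top of this method keeps an aborted request from having that reason overwritten. A minimal, hypothetical stand-in to illustrate the interaction (not the real Req class):

class _ReqStub:
    # Hypothetical stand-in for Req, only to show the early-return behavior.
    def __init__(self):
        self.finished = False
        self.finish_reason = None

    def check_finished(self):
        if self.finished:        # aborted requests stop here
            return
        ...                      # length / EOS / stop-string checks as above

req = _ReqStub()
req.finished = True              # what ModelRpcServer.abort_request() does
req.finish_reason = "abort"      # FinishReason.ABORT in the real code
req.check_finished()             # no-op: the abort reason is preserved
assert req.finish_reason == "abort"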
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids)
# FIXME: This logic does not really solve the problem of determining whether
@@ -132,35 +164,6 @@ class Req:
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
# print("*" * 100)
def check_finished(self):
if self.finished:
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished = True
self.finish_reason = FinishReason.LENGTH
return
if (
self.output_ids[-1] == self.tokenizer.eos_token_id
and self.sampling_params.ignore_eos == False
):
self.finished = True
self.finish_reason = FinishReason.EOS_TOKEN
return
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str:
self.finished = True
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
......
@@ -679,6 +679,7 @@ class ModelRpcServer:
)
def abort_request(self, recv_req):
# Delete requests in the waiting queue
to_del = None
for i, req in enumerate(self.forward_queue):
if req.rid == recv_req.rid:
@@ -688,6 +689,14 @@
if to_del is not None:
del self.forward_queue[to_del]
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid:
req.finished = True
req.finish_reason = FinishReason.ABORT
break
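The handler covers the two places a request can live. A request still in forward_queue has produced no output, so it is dropped outright and never scheduled; a request already in running_batch is only flagged, and it then leaves through the normal completion path with FinishReason.ABORT on the next scheduling step. A condensed, illustrative sketch of that logic (AbortReq is assumed to carry only the request id, as in the tokenizer-manager change below):

def abort_request_sketch(self, recv_req):
    # Illustrative condensation of the handler above, not a drop-in replacement.
    # 1) Waiting requests: nothing has been generated yet, so drop them entirely.
    self.forward_queue = [r for r in self.forward_queue if r.rid != recv_req.rid]
    # 2) Running requests: mark finished with ABORT; the scheduler then returns
    #    the (aborted) result through the usual finish path.
    if self.running_batch:
        for req in self.running_batch.reqs:
            if req.rid == recv_req.rid:
                req.finished = True
                req.finish_reason = FinishReason.ABORT
                break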
class ModelRpcService(rpyc.Service):
exposed_ModelRpcServer = ModelRpcServer
......
@@ -11,6 +11,7 @@ import transformers
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.hf_transformers_utils import (
get_config,
@@ -165,7 +166,7 @@ class TokenizerManager:
while True:
try:
await asyncio.wait_for(event.wait(), timeout=5)
await asyncio.wait_for(event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
self.abort_request(rid)
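The timeout change (5 s to 4 s) here and in the hunk below shortens how long the tokenizer manager waits before polling for a client disconnect. A minimal sketch of the pattern these loops implement, assuming one asyncio.Event per request and a Starlette Request handle; the helper name and error message are illustrative, and raising ValueError mirrors the ValueError handling added to the OpenAI adapters below:

import asyncio

async def _wait_or_abort(tokenizer_manager, rid, event, request=None):
    # Wake up every 4 s; if no output is ready yet, check whether the HTTP client
    # has gone away and, if so, abort the request on the backend.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=4)
            return                                    # output arrived
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                tokenizer_manager.abort_request(rid)  # same method as in this diff
                raise ValueError(f"Request {rid} aborted: client disconnected")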
@@ -243,7 +244,7 @@ class TokenizerManager:
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=5)
await asyncio.wait_for(state.event.wait(), timeout=4)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
@@ -270,10 +271,26 @@ class TokenizerManager:
self.send_to_router.send_pyobj(req)
def abort_request(self, rid):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
def create_abort_task(self, obj):
# Abort the request if the client is disconnected.
async def abort_request():
await asyncio.sleep(3)
if obj.is_single:
self.abort_request(obj.rid)
else:
for rid in obj.rids:
self.abort_request(rid)
background_tasks = BackgroundTasks()
background_tasks.add_task(abort_request)
return background_tasks
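create_abort_task returns a BackgroundTasks object that the route handlers attach to their responses (see the StreamingResponse(..., background=...) changes below). Starlette runs the background task only after the response has finished, so the 3-second sleep followed by abort_request acts as a best-effort cleanup for requests whose client went away before the stream completed; the early return on a missing rid in abort_request keeps it harmless otherwise. A minimal sketch of the wiring, with a stub standing in for the real TokenizerManager:

import asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

class _StubManager:
    # Stand-in for TokenizerManager, only to show the wiring; the real class is
    # sglang.srt.managers.tokenizer_manager.TokenizerManager.
    def create_abort_task(self, rid):
        async def abort_request():
            await asyncio.sleep(3)
            print(f"abort {rid}")  # the real task calls self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

manager = _StubManager()

@app.get("/stream")
async def stream(request: Request):
    async def gen():
        yield "data: hello\n\n"

    # The background task runs after the stream closes, normally or on disconnect.
    return StreamingResponse(gen(), media_type="text/event-stream",
                             background=manager.create_abort_task("rid-123"))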
def create_handle_loop(self):
self.to_create_loop = False
loop = asyncio.get_event_loop()
......
"""Conversion between OpenAI APIs and native SRT APIs"""
import asyncio
import json
import os
from http import HTTPStatus
from fastapi import HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi import Request
from fastapi.responses import StreamingResponse, JSONResponse
from sglang.srt.conversation import (
Conversation,
@@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import (
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
ErrorResponse,
LogProbs,
UsageInfo,
)
from sglang.srt.utils import jsonify_pydantic_model
chat_template_name = None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
return JSONResponse(content=error.model_dump(),
status_code=error.code)
def create_streaming_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
json_str = json.dumps({"error": error.model_dump()})
return json_str
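How these helpers are used further down: non-streaming handlers return create_error_response(...) directly as a JSONResponse, while streaming handlers embed create_streaming_error_response(...) in an SSE data: line before [DONE]. A small hypothetical usage showing the two payload shapes (the second message text is made up):

# Hypothetical usage of the helpers defined above.
resp = create_error_response("n != 1 is not supported")
# -> JSONResponse, HTTP 400, body:
#    {"object": "error", "message": "n != 1 is not supported",
#     "type": "BadRequestError", "param": null, "code": 400}

err = create_streaming_error_response("The request was aborted")
sse_chunk = f"data: {err}\n\n"  # followed by "data: [DONE]\n\n" to end the stream
# err is a JSON string of the form {"error": {...ErrorResponse fields...}}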
def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name
@@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = CompletionRequest(**request_json)
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
if request.n != 1:
return create_error_response("n != 1 is not supported")
adapted_request = GenerateReqInput(
text=request.prompt,
@@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
return_text_in_logprobs=True,
stream=request.stream,
)
adapted_request.post_init()
if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
n_prev_token = 0
async for content in tokenizer_manager.generate_request(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
if request.logprobs:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
prefill_token_logprobs = content["meta_info"][
"prefill_token_logprobs"
]
prefill_top_logprobs = content["meta_info"][
"prefill_top_logprobs"
]
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
if request.logprobs:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
prefill_token_logprobs = content["meta_info"][
"prefill_token_logprobs"
]
prefill_top_logprobs = content["meta_info"][
"prefill_top_logprobs"
]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
n_prev_token:
],
)
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
n_prev_token:
],
)
logprobs = None
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
else:
logprobs = None
delta = text[len(stream_buffer) :]
stream_buffer = content["text"]
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=logprobs,
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
object="text_completion",
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
delta = text[len(stream_buffer) :]
stream_buffer = content["text"]
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=logprobs,
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
object="text_completion",
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.model_dump_json()}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
# Non-streaming response.
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
ret = ret[0] if isinstance(ret, list) else ret
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
except ValueError as e:
return create_error_response(str(e))
ret = ret[0] if isinstance(ret, list) else ret
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
text = ret["text"]
@@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
assert request.n == 1
if request.n != 1:
return create_error_response("n != 1 is not supported")
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
},
stream=request.stream,
)
adapted_request.post_init()
if adapted_request.stream:
@@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
is_first = True
stream_buffer = ""
async for content in tokenizer_manager.generate_request(adapted_request):
if is_first:
# First chunk with role
is_first = False
try:
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
if is_first:
# First chunk with role
is_first = False
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
delta=DeltaMessage(content=delta),
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = ChatCompletionStreamResponse(
@@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
choices=[choice_data],
model=request.model,
)
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = text
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta),
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
yield f"data: {chunk.model_dump_json()}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
# Non-streaming response.
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
except ValueError as e:
return create_error_response(str(e))
prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
......
@@ -7,6 +7,14 @@ from pydantic import BaseModel, Field
from typing_extensions import Literal
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
......
@@ -93,7 +93,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream")
return StreamingResponse(stream_results(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj))
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
......
@@ -392,14 +392,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
return obj.model_dump_json()
return response
\ No newline at end of file
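The helper removed above was a pydantic 1.x compatibility shim; with it gone, the call sites changed in this commit use the pydantic 2 API directly (chunk.model_dump_json() instead of jsonify_pydantic_model(chunk)). A one-line illustration with a throwaway model:

# Illustrative only: pydantic 2 serialization as now used by the stream handlers.
from pydantic import BaseModel

class Chunk(BaseModel):  # throwaway model, not from the repo
    id: str
    text: str

print(Chunk(id="cmpl-1", text="hi").model_dump_json())
# -> {"id":"cmpl-1","text":"hi"}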