"src/vscode:/vscode.git/clone" did not exist on "15e3ff878c3d8927b6f6fac702e4f74eaee7607a"
Unverified Commit 8dbdc018 authored by Lianmin Zheng, committed by GitHub

Abort disconnected requests (#457)

parent 3e684be7
@@ -580,8 +580,8 @@ class StreamExecutor:
     def _execute_role_end(self, expr: SglRoleEnd):
         if (
             self.cur_role == "assistant"
-            and self.backend.is_chat_model
             and self.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
         ):
             # Execute the stored lazy generation calls
             self.backend.role_end_generate(self)
...
@@ -19,6 +19,7 @@ class FinishReason(IntEnum):
     EOS_TOKEN = auto()
     LENGTH = auto()
     STOP_STR = auto()
+    ABORT = auto()

     @staticmethod
     def to_str(reason):
@@ -28,6 +29,8 @@ class FinishReason(IntEnum):
             return "length"
         elif reason == FinishReason.STOP_STR:
             return "stop"
+        elif reason == FinishReason.ABORT:
+            return "abort"
         else:
             return None
@@ -86,6 +89,35 @@ class Req:
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens

+    def check_finished(self):
+        if self.finished:
+            return
+
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished = True
+            self.finish_reason = FinishReason.LENGTH
+            return
+
+        if (
+            self.output_ids[-1] == self.tokenizer.eos_token_id
+            and self.sampling_params.ignore_eos == False
+        ):
+            self.finished = True
+            self.finish_reason = FinishReason.EOS_TOKEN
+            return
+
+        if len(self.sampling_params.stop_strs) > 0:
+            tail_str = self.tokenizer.decode(
+                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
+            )
+
+            for stop_str in self.sampling_params.stop_strs:
+                if stop_str in tail_str:
+                    self.finished = True
+                    self.finish_reason = FinishReason.STOP_STR
+                    self.hit_stop_str = stop_str
+                    return
+
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
         # FIXME: This logic does not really solve the problem of determining whether
@@ -132,35 +164,6 @@ class Req:
         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)

-    def check_finished(self):
-        if self.finished:
-            return
-
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished = True
-            self.finish_reason = FinishReason.LENGTH
-            return
-
-        if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
-            and self.sampling_params.ignore_eos == False
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.EOS_TOKEN
-            return
-
-        if len(self.sampling_params.stop_strs) > 0:
-            tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-            )
-
-            for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str:
-                    self.finished = True
-                    self.finish_reason = FinishReason.STOP_STR
-                    self.hit_stop_str = stop_str
-                    return
-
     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
...
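Note: the two hunks above add an ABORT finish reason and move check_finished() ahead of jump_forward_and_retokenize() without changing its body. The sketch below is a standalone re-statement (not the sglang module itself) showing how an aborted request is expected to surface to clients as the string "abort" in meta_info["finish_reason"]; the EOS_TOKEN mapping lies outside the hunk shown, so it is omitted here.

    from enum import IntEnum, auto

    class FinishReason(IntEnum):
        EOS_TOKEN = auto()
        LENGTH = auto()
        STOP_STR = auto()
        ABORT = auto()

        @staticmethod
        def to_str(reason):
            # Mirrors the branch added above; anything unlisted falls through to
            # None, matching the `else: return None` branch. (EOS_TOKEN is left
            # out of this sketch because its mapping is not shown in the hunk.)
            return {
                FinishReason.LENGTH: "length",
                FinishReason.STOP_STR: "stop",
                FinishReason.ABORT: "abort",
            }.get(reason)

    print(FinishReason.to_str(FinishReason.ABORT))  # -> abort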
@@ -679,6 +679,7 @@ class ModelRpcServer:
         )

     def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
         to_del = None
         for i, req in enumerate(self.forward_queue):
             if req.rid == recv_req.rid:
@@ -688,6 +689,14 @@ class ModelRpcServer:
         if to_del is not None:
             del self.forward_queue[to_del]

+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished = True
+                    req.finish_reason = FinishReason.ABORT
+                    break
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
...
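Note: a request still in the waiting queue can simply be dropped, but a request already in the running batch owns scheduler state (KV-cache slots, streaming output), so the hunk above only flags it as finished with FinishReason.ABORT and lets the scheduler's normal post-step cleanup remove it. A minimal illustration of that flag-then-filter pattern, with a hypothetical FakeReq standing in for sglang's Req (the filtering step here is an assumption about the surrounding scheduler loop, which is not part of this diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeReq:                      # hypothetical stand-in for sglang's Req
        rid: str
        finished: bool = False
        finish_reason: Optional[str] = None

    running_batch = [FakeReq("a"), FakeReq("b"), FakeReq("c")]

    # abort_request(recv_req) with recv_req.rid == "b": flag it, don't delete in place.
    for req in running_batch:
        if req.rid == "b":
            req.finished = True
            req.finish_reason = "abort"
            break

    # The regular filtering after the next forward step then drops it, so cleanup
    # (final chunk, freeing resources) happens on the scheduler's own terms.
    running_batch = [req for req in running_batch if not req.finished]
    print([req.rid for req in running_batch])  # -> ['a', 'c']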
@@ -11,6 +11,7 @@ import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks

 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -165,7 +166,7 @@ class TokenizerManager:
         while True:
             try:
-                await asyncio.wait_for(event.wait(), timeout=5)
+                await asyncio.wait_for(event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(rid)
@@ -243,7 +244,7 @@ class TokenizerManager:
         while True:
             try:
-                await asyncio.wait_for(state.event.wait(), timeout=5)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
                 break
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
@@ -270,10 +271,26 @@ class TokenizerManager:
         self.send_to_router.send_pyobj(req)

     def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_router.send_pyobj(req)

+    def create_abort_task(self, obj):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
+
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
+
     def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
...
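Note: create_abort_task() returns a starlette BackgroundTasks object that the HTTP endpoints attach to their StreamingResponse (see the server and OpenAI-adapter hunks below). Background tasks run once the response is finished, including when it ends early because the client went away, which is what triggers the abort path; the 4-second wait_for timeout above is the complementary poll that notices a disconnect while a request is still queued or running. A minimal sketch of the same pattern in a stand-alone FastAPI app (hypothetical endpoint, not sglang code):

    import asyncio

    from fastapi import BackgroundTasks, FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    async def fake_stream():
        for i in range(5):
            yield f"data: chunk {i}\n\n"
            await asyncio.sleep(1)

    def cleanup():
        # In sglang this is TokenizerManager.abort_request(rid) for each
        # request id belonging to the call.
        print("stream finished or client disconnected: abort the request here")

    @app.get("/stream")
    async def stream():
        background = BackgroundTasks()
        background.add_task(cleanup)
        # `background=` runs after the streaming response completes or is torn down.
        return StreamingResponse(fake_stream(), media_type="text/event-stream",
                                 background=background)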
"""Conversion between OpenAI APIs and native SRT APIs""" """Conversion between OpenAI APIs and native SRT APIs"""
import asyncio
import json import json
import os import os
from http import HTTPStatus
from fastapi import HTTPException, Request from fastapi import Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from sglang.srt.conversation import ( from sglang.srt.conversation import (
Conversation, Conversation,
@@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model

 chat_template_name = None


+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    return JSONResponse(content=error.model_dump(),
+                        status_code=error.code)
+
+
+def create_streaming_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+    error = ErrorResponse(message=message,
+                          type=err_type,
+                          code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name
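Note: the two helpers above produce payloads along these lines (shown with a message taken from the validation hunks below; the field layout comes from the ErrorResponse model added in the openai_protocol hunk further down). create_error_response() returns a JSONResponse whose body is the serialized ErrorResponse, while create_streaming_error_response() wraps the same object under an "error" key so it can be emitted as an SSE data line. A small runnable sketch of both shapes:

    import json
    from http import HTTPStatus
    from typing import Optional

    from pydantic import BaseModel

    class ErrorResponse(BaseModel):      # same shape as the model added below
        object: str = "error"
        message: str
        type: str
        param: Optional[str] = None
        code: int

    error = ErrorResponse(message="n != 1 is not supported",
                          type="BadRequestError",
                          code=HTTPStatus.BAD_REQUEST.value)

    # Body of the non-streaming JSONResponse:
    print(error.model_dump())
    # {'object': 'error', 'message': 'n != 1 is not supported',
    #  'type': 'BadRequestError', 'param': None, 'code': 400}

    # What a streaming client receives as an SSE event:
    print(f"data: {json.dumps({'error': error.model_dump()})}")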
@@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     adapted_request = GenerateReqInput(
         text=request.prompt,
@@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in tokenizer_manager.generate_request(adapted_request):
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request):
                     text = content["text"]
                     prompt_tokens = content["meta_info"]["prompt_tokens"]
                     completion_tokens = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             # Prepend prompt in response text.
                             text = request.prompt + text

                     if request.logprobs:
                         # The first chunk and echo is enabled.
                         if not stream_buffer and request.echo:
                             prefill_token_logprobs = content["meta_info"][
                                 "prefill_token_logprobs"
                             ]
                             prefill_top_logprobs = content["meta_info"][
                                 "prefill_top_logprobs"
                             ]
                         else:
                             prefill_token_logprobs = None
                             prefill_top_logprobs = None

                         logprobs = to_openai_style_logprobs(
                             prefill_token_logprobs=prefill_token_logprobs,
                             prefill_top_logprobs=prefill_top_logprobs,
                             decode_token_logprobs=content["meta_info"][
                                 "decode_token_logprobs"
                             ][n_prev_token:],
                             decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
                                 n_prev_token:
                             ],
                         )
                         n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                     else:
                         logprobs = None

                     delta = text[len(stream_buffer) :]
                     stream_buffer = content["text"]
                     choice_data = CompletionResponseStreamChoice(
                         index=0,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = CompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         object="text_completion",
                         choices=[choice_data],
                         model=request.model,
                         usage=UsageInfo(
                             prompt_tokens=prompt_tokens,
                             completion_tokens=completion_tokens,
                             total_tokens=prompt_tokens + completion_tokens,
                         ),
                     )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
     ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]
@@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.
@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:
@@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
             stream_buffer = ""
-            async for content in tokenizer_manager.generate_request(adapted_request):
+            try:
+                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
                     if is_first:
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
                         )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                        yield f"data: {chunk.model_dump_json()}\n\n"

                     text = content["text"]
                     delta = text[len(stream_buffer) :]
                     stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
                     )
@@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"

             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(adapted_request))
     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
...
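Note: from a client's point of view, the combined effect of these adapter changes is that /v1/completions and /v1/chat/completions now report validation problems and mid-stream failures as structured error payloads instead of failing on an assert, and that closing the connection mid-stream aborts the request server-side (via the BackgroundTasks hook plus the disconnect poll in TokenizerManager). A small example of consuming the stream with httpx; the host, port, and model name are assumptions for illustration, not part of this diff:

    # Assumes: `pip install httpx` and an sglang OpenAI-compatible server on
    # http://localhost:30000 (host/port/model here are assumptions).
    import json

    import httpx

    payload = {"model": "default", "prompt": "Hello", "max_tokens": 16, "stream": True}

    with httpx.stream("POST", "http://localhost:30000/v1/completions",
                      json=payload, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            event = json.loads(data)
            if "error" in event:
                # Produced by create_streaming_error_response() on a mid-stream failure.
                print("server error:", event["error"]["message"])
                break
            print(event["choices"][0]["text"], end="", flush=True)

    # If this client simply exits mid-stream, the server's abort path marks the
    # request with FinishReason.ABORT and frees its slot in the running batch.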
@@ -7,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal


+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
...
@@ -93,7 +93,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream")
+        return StreamingResponse(stream_results(), media_type="text/event-stream",
+                                 background=tokenizer_manager.create_abort_task(obj))
     else:
         try:
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
...
@@ -392,14 +392,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid API Key"},
             )
         response = await call_next(request)
-        return response
-
-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+        return response
\ No newline at end of file
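Note: the helper removed above existed only to paper over the pydantic 1.x / 2.x split; as the adapter hunks show, the call sites now rely on pydantic 2's built-in serializer directly. A quick sketch:

    # Requires pydantic >= 2 (model_dump_json); under pydantic 1.x the removed
    # jsonify_pydantic_model() shim would have fallen back to .json().
    from pydantic import BaseModel

    class Chunk(BaseModel):
        id: str
        text: str

    print(Chunk(id="cmpl-1", text="hello").model_dump_json())
    # -> {"id":"cmpl-1","text":"hello"}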