sglang / Commits / Commit 8dbdc018 (unverified)

Abort disconnected requests (#457)

Authored by Lianmin Zheng on May 20, 2024; committed by GitHub on May 20, 2024.
Parent commit: 3e684be7

Showing 8 changed files with 205 additions and 135 deletions (+205 -135).
python/sglang/lang/interpreter.py                    +1   -1
python/sglang/srt/managers/router/infer_batch.py     +32  -29
python/sglang/srt/managers/router/model_rpc.py       +9   -0
python/sglang/srt/managers/tokenizer_manager.py      +19  -2
python/sglang/srt/openai_api_adapter.py              +133 -91
python/sglang/srt/openai_protocol.py                 +8   -0
python/sglang/srt/server.py                          +2   -1
python/sglang/srt/utils.py                           +1   -11
python/sglang/lang/interpreter.py  (+1 -1)

@@ -580,8 +580,8 @@ class StreamExecutor:
     def _execute_role_end(self, expr: SglRoleEnd):
         if (
             self.cur_role == "assistant"
-            and self.backend.is_chat_model
             and self.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
         ):
             # Execute the stored lazy generation calls
             self.backend.role_end_generate(self)
python/sglang/srt/managers/router/infer_batch.py  (+32 -29)

@@ -19,6 +19,7 @@ class FinishReason(IntEnum):
     EOS_TOKEN = auto()
     LENGTH = auto()
     STOP_STR = auto()
+    ABORT = auto()

     @staticmethod
     def to_str(reason):

@@ -28,6 +29,8 @@ class FinishReason(IntEnum):
             return "length"
         elif reason == FinishReason.STOP_STR:
             return "stop"
+        elif reason == FinishReason.ABORT:
+            return "abort"
         else:
             return None

@@ -86,6 +89,35 @@ class Req:
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens

+    def check_finished(self):
+        if self.finished:
+            return
+
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished = True
+            self.finish_reason = FinishReason.LENGTH
+            return
+
+        if (
+            self.output_ids[-1] == self.tokenizer.eos_token_id
+            and self.sampling_params.ignore_eos == False
+        ):
+            self.finished = True
+            self.finish_reason = FinishReason.EOS_TOKEN
+            return
+
+        if len(self.sampling_params.stop_strs) > 0:
+            tail_str = self.tokenizer.decode(
+                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
+            )
+
+            for stop_str in self.sampling_params.stop_strs:
+                if stop_str in tail_str:
+                    self.finished = True
+                    self.finish_reason = FinishReason.STOP_STR
+                    self.hit_stop_str = stop_str
+                    return
+
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
         # FIXME: This logic does not really solve the problem of determining whether

@@ -132,35 +164,6 @@ class Req:
         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)

-    def check_finished(self):
-        ...  (identical body to the check_finished method added above; the method
-        ...   is moved earlier in the class, out of its old position before __repr__)
-
     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
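For reference, a minimal runnable sketch of how the new ABORT member surfaces as a finish-reason string. The enum values and the "length"/"stop"/"abort" mapping mirror the hunks above; the EOS_TOKEN branch is an assumption, since it falls outside the diff context.

from enum import IntEnum, auto


class FinishReason(IntEnum):
    EOS_TOKEN = auto()
    LENGTH = auto()
    STOP_STR = auto()
    ABORT = auto()  # new in this commit

    @staticmethod
    def to_str(reason):
        if reason == FinishReason.EOS_TOKEN:
            return "stop"  # assumed; this branch is outside the diff context
        elif reason == FinishReason.LENGTH:
            return "length"
        elif reason == FinishReason.STOP_STR:
            return "stop"
        elif reason == FinishReason.ABORT:
            return "abort"
        else:
            return None


# An aborted request is reported to API clients with finish_reason == "abort".
print(FinishReason.to_str(FinishReason.ABORT))  # -> abort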
python/sglang/srt/managers/router/model_rpc.py  (+9 -0)

@@ -679,6 +679,7 @@ class ModelRpcServer:
         )

     def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
         to_del = None
         for i, req in enumerate(self.forward_queue):
             if req.rid == recv_req.rid:

@@ -688,6 +689,14 @@ class ModelRpcServer:
         if to_del is not None:
             del self.forward_queue[to_del]

+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished = True
+                    req.finish_reason = FinishReason.ABORT
+                    break
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
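A rough, self-contained illustration of the two abort paths above: a request still in the waiting queue is dropped outright, while a request already in the running batch is only flagged so the scheduler retires it with finish_reason ABORT on its next step. Req, FinishReason, and the queue names loosely mirror the code; everything else is a simplified stand-in, not sglang's actual classes.

from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional


class FinishReason(IntEnum):
    EOS_TOKEN = auto()
    LENGTH = auto()
    STOP_STR = auto()
    ABORT = auto()


@dataclass
class Req:  # simplified stand-in for sglang's Req
    rid: str
    finished: bool = False
    finish_reason: Optional[FinishReason] = None


@dataclass
class Router:  # simplified stand-in for ModelRpcServer
    forward_queue: List[Req] = field(default_factory=list)
    running_reqs: List[Req] = field(default_factory=list)

    def abort_request(self, rid: str) -> None:
        # A request still waiting to be scheduled is simply removed.
        self.forward_queue = [r for r in self.forward_queue if r.rid != rid]
        # A request already in the running batch is only flagged; the decode
        # loop drops it on its next iteration and reports "abort".
        for req in self.running_reqs:
            if req.rid == rid:
                req.finished = True
                req.finish_reason = FinishReason.ABORT
                break


router = Router(forward_queue=[Req("a")], running_reqs=[Req("b")])
router.abort_request("a")
router.abort_request("b")
print(len(router.forward_queue), router.running_reqs[0].finish_reason.name)  # 0 ABORT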
python/sglang/srt/managers/tokenizer_manager.py  (+19 -2)

@@ -11,6 +11,7 @@ import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks

 from sglang.srt.hf_transformers_utils import (
     get_config,

@@ -165,7 +166,7 @@ class TokenizerManager:
         while True:
             try:
-                await asyncio.wait_for(event.wait(), timeout=5)
+                await asyncio.wait_for(event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(rid)

@@ -243,7 +244,7 @@ class TokenizerManager:
         while True:
             try:
-                await asyncio.wait_for(state.event.wait(), timeout=5)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
                 break
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():

@@ -270,10 +271,26 @@ class TokenizerManager:
         self.send_to_router.send_pyobj(req)

     def abort_request(self, rid):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_router.send_pyobj(req)

+    def create_abort_task(self, obj):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
+
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
+
     def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
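create_abort_task relies on a standard FastAPI/Starlette behaviour: a BackgroundTasks object passed as background= to a response runs after the response finishes, including when the client drops a streaming connection midway. A minimal standalone sketch of that pattern follows; the endpoint, the set-based bookkeeping, and the cleanup body are illustrative only, not sglang's actual wiring (the 3-second grace period matches the commit).

import asyncio

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
active_rids = set()  # illustrative stand-in for TokenizerManager.rid_to_state


def create_abort_task(rid: str) -> BackgroundTasks:
    # Starlette runs this after the response is done, whether the stream
    # completed normally or the client disconnected partway through.
    async def abort_request():
        await asyncio.sleep(3)  # small grace period, as in the commit
        active_rids.discard(rid)  # sglang would send an AbortReq to the router here

    tasks = BackgroundTasks()
    tasks.add_task(abort_request)
    return tasks


@app.post("/stream/{rid}")
async def stream(rid: str):
    active_rids.add(rid)

    async def gen():
        for i in range(100):
            yield f"data: chunk {i}\n\n"
            await asyncio.sleep(0.1)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        gen(), media_type="text/event-stream", background=create_abort_task(rid)
    )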
python/sglang/srt/openai_api_adapter.py  (+133 -91)

 """Conversion between OpenAI APIs and native SRT APIs"""

 import asyncio
 import json
 import os
 from http import HTTPStatus

-from fastapi import HTTPException, Request
-from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import StreamingResponse, JSONResponse

 from sglang.srt.conversation import (
     Conversation,

@@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    ErrorResponse,
     LogProbs,
     UsageInfo,
 )
-from sglang.srt.utils import jsonify_pydantic_model

 chat_template_name = None


+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
 def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name

@@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     adapted_request = GenerateReqInput(
         text=request.prompt,

@@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         return_text_in_logprobs=True,
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

         async def generate_stream_resp():
             stream_buffer = ""
             n_prev_token = 0
-            async for content in tokenizer_manager.generate_request(adapted_request):
+            try:
+                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
                     text = content["text"]
                     prompt_tokens = content["meta_info"]["prompt_tokens"]
                     completion_tokens = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:
                         # The first chunk
                         if request.echo:
                             # Prepend prompt in response text.
                             text = request.prompt + text

                     if request.logprobs:
                         # The first chunk and echo is enabled.
                         if not stream_buffer and request.echo:
                             prefill_token_logprobs = content["meta_info"]["prefill_token_logprobs"]
                             prefill_top_logprobs = content["meta_info"]["prefill_top_logprobs"]
                         else:
                             prefill_token_logprobs = None
                             prefill_top_logprobs = None

                         logprobs = to_openai_style_logprobs(
                             prefill_token_logprobs=prefill_token_logprobs,
                             prefill_top_logprobs=prefill_top_logprobs,
                             decode_token_logprobs=content["meta_info"]["decode_token_logprobs"][n_prev_token:],
                             decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][n_prev_token:],
                         )
                         n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                     else:
                         logprobs = None

                     delta = text[len(stream_buffer):]
                     stream_buffer = content["text"]
                     choice_data = CompletionResponseStreamChoice(
                         index=0,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = CompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         object="text_completion",
                         choices=[choice_data],
                         model=request.model,
                         usage=UsageInfo(
                             prompt_tokens=prompt_tokens,
                             completion_tokens=completion_tokens,
                             total_tokens=prompt_tokens + completion_tokens,
                         ),
                     )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
     ret = ret[0] if isinstance(ret, list) else ret

     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     text = ret["text"]

@@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
     request = ChatCompletionRequest(**request_json)

-    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
-    assert request.n == 1
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")

     # Prep the data needed for the underlying GenerateReqInput:
     #  - prompt: The full prompt string.

@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         },
         stream=request.stream,
     )
-    adapted_request.post_init()

     if adapted_request.stream:

@@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
@@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         async def generate_stream_resp():
             is_first = True
             stream_buffer = ""
-            async for content in tokenizer_manager.generate_request(adapted_request):
+            try:
+                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
                     if is_first:
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
                         )
-                        yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                        yield f"data: {chunk.model_dump_json()}\n\n"

                     text = content["text"]
                     delta = text[len(stream_buffer):]
                     stream_buffer = text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
                     )
-                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
-    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    try:
+        ret = await tokenizer_manager.generate_request(adapted_request, raw_request).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))

     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
     choice_data = ChatCompletionResponseChoice(
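On the client side, the streaming behaviour changed here can be exercised with any SSE-capable HTTP client: errors now arrive as a data: {"error": ...} event instead of a raised exception, and closing the connection early is what triggers the server's abort path. A hedged sketch using httpx; the URL, port, and model name are assumptions, not values defined by this commit.

import json

import httpx

payload = {"model": "default", "prompt": "Say hi", "max_tokens": 32, "stream": True}

with httpx.Client(timeout=None) as client:
    with client.stream("POST", "http://127.0.0.1:30000/v1/completions", json=payload) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            event = json.loads(data)
            if "error" in event:  # shape produced by create_streaming_error_response
                print("server error:", event["error"]["message"])
                break
            print(event["choices"][0]["text"], end="", flush=True)

# Breaking out of the stream (or killing the client) closes the connection,
# which lets the server's create_abort_task() abort the request.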
python/sglang/srt/openai_protocol.py  (+8 -0)

@@ -7,6 +7,14 @@ from pydantic import BaseModel, Field
 from typing_extensions import Literal


+class ErrorResponse(BaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
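For reference, the payload this model serializes to (pydantic v2, which the adapter already assumes via model_dump); the example values mirror the n != 1 check in openai_api_adapter.py.

from typing import Optional

from pydantic import BaseModel


class ErrorResponse(BaseModel):  # same fields as the model added above
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


error = ErrorResponse(message="n != 1 is not supported", type="BadRequestError", code=400)
print(error.model_dump())
# {'object': 'error', 'message': 'n != 1 is not supported', 'type': 'BadRequestError',
#  'param': None, 'code': 400}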
python/sglang/srt/server.py  (+2 -1)

@@ -93,7 +93,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream")
+        return StreamingResponse(
+            stream_results(), media_type="text/event-stream", background=tokenizer_manager.create_abort_task(obj)
+        )
     else:
         try:
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
python/sglang/srt/utils.py  (+1 -11)

@@ -392,14 +392,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid API Key"},
             )
         response = await call_next(request)
-        return response
-
-
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+        return response
\ No newline at end of file
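The removed shim only existed to bridge the pydantic 1/2 split; after this commit the adapter calls pydantic v2's model_dump_json() directly when emitting stream chunks. A quick illustration with a made-up model; the class and field names are not from sglang.

from pydantic import BaseModel


class Chunk(BaseModel):
    id: str
    text: str


chunk = Chunk(id="cmpl-1", text="héllo")
# Replaces the old jsonify_pydantic_model(obj) call, which used
# obj.json(ensure_ascii=False) on pydantic 1.x.
print(chunk.model_dump_json())  # {"id":"cmpl-1","text":"héllo"}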