zhaoyu6 / sglang · commit 2f1d9283 (unverified)

[FEAT] Support batches cancel (#1222)

Authored Aug 27, 2024 by caiyueliang; committed by GitHub Aug 26, 2024
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: c61a1b6f
Showing 3 changed files with 122 additions and 6 deletions (+122 / -6):
  python/sglang/srt/openai_api/adapter.py   +83  -4
  python/sglang/srt/server.py               +7   -0
  test/srt/test_openai_server.py            +32  -2
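From the client side, the new cancel support is exercised the same way as the OpenAI Batch API (this is also how the new test below drives it). A minimal sketch, assuming a local sglang OpenAI-compatible server at http://127.0.0.1:30000/v1 and an input JSONL file like the one the test writes; the file name, base URL, and API key are placeholders:

import time
import openai

# Placeholder address/key for a locally running sglang server.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Upload a JSONL batch input file (each line carries a unique "custom_id").
with open("complete_input.jsonl", "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Cancel while the batch is still "validating" or "in_progress".
batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"

# Poll until the server-side task has aborted the in-flight requests.
while batch_job.status not in ["failed", "cancelled"]:
    time.sleep(3)
    batch_job = client.batches.retrieve(batch_job.id)
print(batch_job.status)  # "cancelled" on success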
python/sglang/srt/openai_api/adapter.py  (+83, -4)
@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         end_point = batch_storage[batch_id].endpoint
         file_request_list = []
         all_requests = []
+        request_ids = []
         for line in lines:
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+            request_ids.append(request_data["custom_id"])

             # Although streaming is supported for standalone completions, it is not supported in
             # batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
                 all_requests.append(CompletionRequest(**body))
         if end_point == "/v1/chat/completions":
             adapted_request, request = v1_chat_generate_request(
-                all_requests, tokenizer_manager
+                all_requests, tokenizer_manager, request_ids=request_ids
             )
         elif end_point == "/v1/completions":
-            adapted_request, request = v1_generate_request(all_requests)
+            adapted_request, request = v1_generate_request(
+                all_requests, request_ids=request_ids
+            )
         try:
             ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
             if not isinstance(ret, list):
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 }
                 all_ret.append(response_json)
                 completed_requests += 1

         # Write results to a new file
         output_file_id = f"backend_result_file-{uuid.uuid4()}"
         global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response


+async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Only do cancal when status is "validating" or "in_progress"
+    if batch_response.status in ["validating", "in_progress"]:
+        # Start cancelling the batch asynchronously
+        asyncio.create_task(
+            cancel_batch(
+                tokenizer_manager=tokenizer_manager,
+                batch_id=batch_id,
+                input_file_id=batch_response.input_file_id,
+            )
+        )
+
+        # Update batch status to "cancelling"
+        batch_response.status = "cancelling"
+
+        return batch_response
+    else:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Current status is {batch_response.status}, no need to cancel",
+        )
+
+
+async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+    try:
+        # Update the batch status to "cancelling"
+        batch_storage[batch_id].status = "cancelling"
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        file_request_list = []
+        request_ids = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            request_ids.append(request_data["custom_id"])
+
+        # Cancel requests by request_ids
+        for rid in request_ids:
+            tokenizer_manager.abort_request(rid=rid)
+
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "cancelled"
+    except Exception as e:
+        logger.error("error in SGLang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
 async def v1_retrieve_file(file_id: str):
     # Retrieve the batch job from the in-memory storage
     file_response = file_id_response.get(file_id)
@@ -392,7 +465,9 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")


-def v1_generate_request(all_requests: List[CompletionRequest]):
+def v1_generate_request(
+    all_requests: List[CompletionRequest], request_ids: List[str] = None
+):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
         logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
+        rid=request_ids,
     )

     if len(all_requests) == 1:
@@ -746,7 +822,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):


 def v1_chat_generate_request(
-    all_requests: List[ChatCompletionRequest], tokenizer_manager
+    all_requests: List[ChatCompletionRequest],
+    tokenizer_manager,
+    request_ids: List[str] = None,
 ):
     input_ids = []
     sampling_params_list = []
@@ -834,6 +912,7 @@ def v1_chat_generate_request(
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
+        rid=request_ids,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
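Cancellation works because process_batch now forwards each input line's "custom_id" as the request id (rid), and cancel_batch later re-reads the same input file and calls tokenizer_manager.abort_request(rid=...) for each of those ids. A minimal sketch of a batch input JSONL in the expected shape; only "custom_id" and "body" are read by the adapter above, while the "method" and "url" fields follow OpenAI's batch format and the model name is a placeholder:

import json

# Each line is one request; "custom_id" doubles as the sglang request id (rid),
# and "body" is the payload handed to ChatCompletionRequest / CompletionRequest.
requests_lines = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": f"Question {i}?"}],
            "max_tokens": 32,
        },
    }
    for i in range(2)
]

with open("complete_input.jsonl", "w") as f:
    for line in requests_lines:
        f.write(json.dumps(line) + "\n")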
python/sglang/srt/server.py  (+7, -0)
@@ -59,6 +59,7 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
+    v1_cancel_batch,
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
@@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)


+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(tokenizer_manager, batch_id)
+
+
 @app.get("/v1/batches/{batch_id}")
 async def retrieve_batch(batch_id: str):
     return await v1_retrieve_batch(batch_id)
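At the HTTP level the new route takes no request body and returns the updated batch object, mirroring POST /v1/batches/{batch_id}/cancel in the OpenAI API. A rough sketch using the requests library; the server address and the batch id are placeholders:

import requests

base_url = "http://127.0.0.1:30000"  # assumed local sglang server address
batch_id = "batch_xxx"               # id returned when the batch was created

# New in this commit: ask the server to cancel a validating/in-progress batch.
resp = requests.post(f"{base_url}/v1/batches/{batch_id}/cancel")
print(resp.json()["status"])  # expected: "cancelling"

# Existing route: poll until the background task reports "cancelled" (or "failed").
resp = requests.get(f"{base_url}/v1/batches/{batch_id}")
print(resp.json()["status"])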
test/srt/test_openai_server.py  (+32, -2)
@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase):
                 index, True
             ), f"index {index} is not found in the response"

-    def run_batch(self, mode):
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+    def _create_batch(self, mode, client):
         if mode == "completion":
             input_file_path = "complete_input.jsonl"
             # write content to input file
@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase):
                 },
             },
         ]
         with open(input_file_path, "w") as file:
             for line in content:
                 file.write(json.dumps(line) + "\n")
         with open(input_file_path, "rb") as file:
             uploaded_file = client.files.create(file=file, purpose="batch")
         if mode == "completion":
@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase):
             endpoint=endpoint,
             completion_window=completion_window,
         )
+        return batch_job, content
+
+    def run_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        batch_job, content = self._create_batch(mode=mode, client=client)
+
         while batch_job.status not in ["completed", "failed", "cancelled"]:
             time.sleep(3)
             print(
@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
         ]
         assert len(results) == len(content)

+    def run_cancel_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        batch_job, _ = self._create_batch(mode=mode, client=client)
+        assert batch_job.status not in ["cancelling", "cancelled"]
+
+        batch_job = client.batches.cancel(batch_id=batch_job.id)
+        assert batch_job.status == "cancelling"
+
+        while batch_job.status not in ["failed", "cancelled"]:
+            batch_job = client.batches.retrieve(batch_job.id)
+            print(
+                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+            )
+            time.sleep(3)
+
+        assert batch_job.status == "cancelled"
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
         for mode in ["completion", "chat"]:
             self.run_batch(mode)

+    def test_calcel_batch(self):
+        for mode in ["completion", "chat"]:
+            self.run_cancel_batch(mode)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)