change / sglang · Commits · 795eab6d

Commit 795eab6d (unverified), authored Aug 07, 2024 by yichuan~, committed by GitHub Aug 06, 2024

Add support for Batch API test (#936)

Parent: 41bb1ab1

Showing 5 changed files with 164 additions and 4 deletions:

- python/sglang/srt/managers/tokenizer_manager.py (+0, -2)
- python/sglang/srt/openai_api/adapter.py (+23, -2)
- python/sglang/srt/openai_api/protocol.py (+6, -0)
- python/sglang/srt/server.py (+7, -0)
- test/srt/test_openai_server.py (+128, -0)
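Taken together, these changes wire up an OpenAI-compatible Batch API path: upload a .jsonl file of requests, create a batch job against it, poll until it finishes, download the result file, and delete both files. A minimal sketch of that flow, mirroring what the new test does (the base_url and api_key below are placeholders for a locally running sglang server, not values from this commit):

```python
# Minimal sketch of the Batch API flow this commit adds, mirroring
# test_openai_server.py. Assumes an sglang server is already running;
# the URL and key are placeholders.
import json
import time

import openai

client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

# 1. Upload a .jsonl file where each line is one request.
with open("complete_input.jsonl", "rb") as f:
    uploaded = client.files.create(file=f, purpose="batch")

# 2. Create the batch job against the uploaded file.
job = client.batches.create(
    input_file_id=uploaded.id,
    endpoint="/v1/completions",
    completion_window="24h",
)

# 3. Poll until the job reaches a terminal state.
while job.status not in ["completed", "failed", "cancelled"]:
    time.sleep(3)
    job = client.batches.retrieve(job.id)

# 4. Read the results and clean up both files.
results = [
    json.loads(line)
    for line in client.files.content(job.output_file_id).read().splitlines()
]
client.files.delete(uploaded.id)
client.files.delete(job.output_file_id)
```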
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -308,7 +308,6 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
-
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
```

```diff
@@ -341,7 +340,6 @@ class TokenizerManager:
                 )
             assert state.finished
             del self.rid_to_state[rid]
         yield output_list
 
-
     def _validate_input_length(self, input_ids: List[int]):
```
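Both hunks only drop stray blank lines, but the surrounding context shows the fan-in pattern the batch path relies on: one asyncio.Event and ReqState per request id, then a loop that waits on each event before yielding the collected outputs. The following standalone sketch illustrates that pattern; the names mirror the diff, but this is an illustration, not sglang's actual implementation:

```python
# Simplified sketch of the per-request event/state fan-in around these
# hunks: register a ReqState (with an asyncio.Event) per request id,
# let a producer set the event on completion, and collect in order.
import asyncio
from dataclasses import dataclass, field


@dataclass
class ReqState:
    out_list: list = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


async def wait_for_batch(rid_to_state, rids):
    output_list = []
    for rid in rids:
        state = rid_to_state[rid]
        await state.event.wait()  # set by the producer when the reply arrives
        assert state.finished
        output_list.append(state.out_list)
        del rid_to_state[rid]  # matches the cleanup in the second hunk
    return output_list


async def main():
    rid_to_state = {"r1": ReqState(), "r2": ReqState()}

    async def finish(rid, value):
        state = rid_to_state[rid]
        state.out_list.append(value)
        state.finished = True
        state.event.set()

    waiter = asyncio.create_task(wait_for_batch(rid_to_state, ["r1", "r2"]))
    await finish("r2", "second")  # completion order doesn't matter
    await finish("r1", "first")
    print(await waiter)  # [['first'], ['second']]


asyncio.run(main())
```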
python/sglang/srt/openai_api/adapter.py

```diff
@@ -53,6 +53,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileDeleteResponse,
     FileRequest,
     FileResponse,
     LogProbs,
```
```diff
@@ -174,6 +175,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
         return {"error": "Invalid input", "details": e.errors()}
 
 
+async def v1_delete_file(file_id: str):
+    # Retrieve the file job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    file_path = file_id_storage.get(file_id)
+    if file_path is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    os.remove(file_path)
+    del file_id_response[file_id]
+    del file_id_storage[file_id]
+    return FileDeleteResponse(id=file_id, deleted=True)
+
+
 async def v1_batches(tokenizer_manager, raw_request: Request):
     try:
         body = await raw_request.json()
```
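v1_delete_file removes the on-disk file and both in-memory registries before returning an OpenAI-style deletion object; a second delete of the same id therefore yields a 404. From the client side, deletion looks like this (placeholder base_url, api_key, and file id):

```python
# Deleting an uploaded (or result) file via the OpenAI client; the
# server answers with FileDeleteResponse. The id is a placeholder.
import openai

client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")
resp = client.files.delete("file-abc123")  # hypothetical file id
assert resp.deleted
```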
```diff
@@ -287,6 +302,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
     retrieve_batch = batch_storage[batch_id]
     retrieve_batch.output_file_id = output_file_id
     file_id_storage[output_file_id] = output_file_path
+    file_id_response[output_file_id] = FileResponse(
+        id=output_file_id,
+        bytes=os.path.getsize(output_file_path),
+        created_at=int(time.time()),
+        filename=f"{output_file_id}.jsonl",
+        purpose="batch_result",
+    )
     # Update batch status to "completed"
     retrieve_batch.status = "completed"
     retrieve_batch.completed_at = int(time.time())
```
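Registering a FileResponse for the output file is what lets clients treat batch results like any other stored file. Building on the client above, the result file can then be read back through the files API, as the new test does (the file id below is a placeholder):

```python
# Once process_batch registers the output file, its content is served
# through the regular files API. The id is a placeholder; `client` is
# the openai.Client from the earlier sketch.
result = client.files.content("file-out-123")  # hypothetical output id
jsonl_bytes = result.read()  # one JSON object per line
```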
```diff
@@ -380,7 +402,7 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+        if isinstance(prompts[0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
```
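The replaced condition referenced an undefined name, propmt, so it raised NameError whenever the first operand was false; the fix dispatches on the first element's type alone. In sketch form:

```python
# Sketch of the dispatch the fixed branch performs: a batch of plain
# strings is sent as "text", a batch of token-id lists as "input_ids".
def make_prompt_kwargs(prompts):
    if isinstance(prompts[0], str):
        return {"text": prompts}       # e.g. ["List 3 players: ", ...]
    return {"input_ids": prompts}      # e.g. [[1, 2, 3], [4, 5]]
```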
```diff
@@ -931,7 +953,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             ).__anext__()
         except ValueError as e:
             return create_error_response(str(e))
 
-
     if not isinstance(ret, list):
         ret = [ret]
```
python/sglang/srt/openai_api/protocol.py

```diff
@@ -95,6 +95,12 @@ class FileResponse(BaseModel):
     purpose: str
 
 
+class FileDeleteResponse(BaseModel):
+    id: str
+    object: str = "file"
+    deleted: bool
+
+
 class BatchRequest(BaseModel):
     input_file_id: (
         str  # The ID of an uploaded file that contains requests for the new batch
```
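BatchRequest is the Pydantic model that v1_batches validates the POST body against. Judging from the fields the new test passes to client.batches.create(), a minimal body looks like this (the file id is a placeholder):

```python
# Minimal JSON body for POST /v1/batches, matching the fields the test
# supplies. The input_file_id is a placeholder.
batch_body = {
    "input_file_id": "file-abc123",
    "endpoint": "/v1/completions",     # or "/v1/chat/completions"
    "completion_window": "24h",
}
```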
python/sglang/srt/server.py

```diff
@@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import (
     v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_delete_file,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
```

```diff
@@ -175,6 +176,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
     )
 
 
+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
```
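The route layer stays thin: each handler forwards to the adapter, passing the module-level tokenizer_manager where generation is involved. The new DELETE route can also be exercised without the openai client; a sketch using the requests library against a placeholder local server and file id:

```python
# Hitting the new DELETE route over plain HTTP; host, port, and file id
# are placeholders for a locally running sglang server.
import requests

resp = requests.delete("http://localhost:30000/v1/files/file-abc123")
print(resp.status_code, resp.json())  # 200 with a FileDeleteResponse body,
                                      # or 404 if the id is unknown
```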
test/srt/test_openai_server.py

```diff
 import json
+import time
 import unittest
 
 import openai
```
```diff
@@ -207,6 +208,129 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.id
         assert response.created
 
+    def run_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        if mode == "completion":
+            input_file_path = "complete_input.jsonl"
+            # write content to input file
+            content = [
+                {
+                    "custom_id": "request-1",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 3 names of famous soccer player: ",
+                        "max_tokens": 20,
+                    },
+                },
+                {
+                    "custom_id": "request-2",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 6 names of famous basketball player: ",
+                        "max_tokens": 40,
+                    },
+                },
+                {
+                    "custom_id": "request-3",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 6 names of famous tenniss player: ",
+                        "max_tokens": 40,
+                    },
+                },
+            ]
+        else:
+            input_file_path = "chat_input.jsonl"
+            content = [
+                {
+                    "custom_id": "request-1",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-0125",
+                        "messages": [
+                            {
+                                "role": "system",
+                                "content": "You are a helpful assistant.",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Hello! List 3 NBA players and tell a story",
+                            },
+                        ],
+                        "max_tokens": 30,
+                    },
+                },
+                {
+                    "custom_id": "request-2",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-0125",
+                        "messages": [
+                            {"role": "system", "content": "You are an assistant. "},
+                            {
+                                "role": "user",
+                                "content": "Hello! List three capital and tell a story",
+                            },
+                        ],
+                        "max_tokens": 50,
+                    },
+                },
+            ]
+
+        with open(input_file_path, "w") as file:
+            for line in content:
+                file.write(json.dumps(line) + "\n")
+
+        with open(input_file_path, "rb") as file:
+            uploaded_file = client.files.create(file=file, purpose="batch")
+        if mode == "completion":
+            endpoint = "/v1/completions"
+        elif mode == "chat":
+            endpoint = "/v1/chat/completions"
+        completion_window = "24h"
+        batch_job = client.batches.create(
+            input_file_id=uploaded_file.id,
+            endpoint=endpoint,
+            completion_window=completion_window,
+        )
+        while batch_job.status not in ["completed", "failed", "cancelled"]:
+            time.sleep(3)
+            print(
+                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+            )
+            batch_job = client.batches.retrieve(batch_job.id)
+        assert batch_job.status == "completed"
+        assert batch_job.request_counts.completed == len(content)
+        assert batch_job.request_counts.failed == 0
+        assert batch_job.request_counts.total == len(content)
+
+        result_file_id = batch_job.output_file_id
+        file_response = client.files.content(result_file_id)
+        result_content = file_response.read()
+        if mode == "completion":
+            result_file_name = "batch_job_complete_results.jsonl"
+        else:
+            result_file_name = "batch_job_chat_results.jsonl"
+        with open(result_file_name, "wb") as file:
+            file.write(result_content)
+        results = []
+        with open(result_file_name, "r", encoding="utf-8") as file:
+            for line in file:
+                json_object = json.loads(line.strip())
+                results.append(json_object)
+        for delete_fid in [uploaded_file.id, result_file_id]:
+            del_pesponse = client.files.delete(delete_fid)
+            assert del_pesponse.deleted
+        assert len(results) == len(content)
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
```
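run_batch finishes by checking the request counts and deleting both files. To inspect an individual result line, one would index into results; note that the field names below assume sglang mirrors OpenAI's batch output schema (custom_id plus a response wrapper), which this diff does not itself show:

```python
# Assumes OpenAI's batch output schema; the field names are an
# assumption, not taken from this diff. `results` is the list parsed
# at the end of run_batch.
first = results[0]
print(first["custom_id"])            # e.g. "request-1"
body = first["response"]["body"]     # assumed wrapper around the payload
print(body["choices"][0])            # the generated completion/message
```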
```diff
@@ -237,6 +361,10 @@ class TestOpenAIServer(unittest.TestCase):
             for logprobs in [None, 5]:
                 self.run_chat_completion_stream(logprobs)
 
+    def test_batch(self):
+        for mode in ["completion", "chat"]:
+            self.run_batch(mode)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
```