change / sglang / Commits / Commit 084fa54d (Unverified)

Add support for OpenAI API : offline batch(file) processing (#699)

Authored Jul 30, 2024 by yichuan~; committed by GitHub on Jul 29, 2024
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
Parent: eba458bd

Showing 10 changed files with 839 additions and 154 deletions (+839 −154)
Files changed:

  .pre-commit-config.yaml                           +1    −1
  examples/usage/openai_batch_chat.py               +86   −0
  examples/usage/openai_batch_complete.py           +86   −0
  examples/usage/openai_parallel_sample.py          +37   −0
  python/sglang/srt/managers/io_struct.py           +20   −2
  python/sglang/srt/managers/tokenizer_manager.py   +13   −11
  python/sglang/srt/openai_api/adapter.py           +505  −139
  python/sglang/srt/openai_api/protocol.py          +49   −0
  python/sglang/srt/server.py                       +35   −1
  python/sglang/srt/server_args.py                  +7    −0
.pre-commit-config.yaml  (+1 −1)

@@ -4,6 +4,6 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: stable
+    rev: 24.4.2
     hooks:
       - id: black
examples/usage/openai_batch_chat.py  (new file, +86 −0)

import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_chat_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
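For reference, here is a plausible "input.jsonl" for the script above (my own sketch, not part of the commit). The lines follow the OpenAI batch input format; the new adapter only reads the "custom_id" and "body" fields of each line, with "body" parsed as a ChatCompletionRequest, so the "method" and "url" keys are carried along purely for compatibility:

# Sketch: build a two-request input.jsonl for the chat batch example above.
# Only "custom_id" and "body" are consumed by the new /v1/batches handler.
import json

questions = ["What is the capital of France?", "List 3 primary colors."]
with open("input.jsonl", "w", encoding="utf-8") as f:
    for i, q in enumerate(questions, start=1):
        line = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "default",
                "messages": [
                    {"role": "system", "content": "You are a helpful AI assistant"},
                    {"role": "user", "content": q},
                ],
                "max_tokens": 32,
            },
        }
        f.write(json.dumps(line) + "\n")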
examples/usage/openai_batch_complete.py  (new file, +86 −0)

import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_complete_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
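The completions variant expects the same line structure, with "body" parsed as a CompletionRequest instead. A matching sketch (again my own, with illustrative values):

# Sketch: build input_complete.jsonl for the completion batch example above.
import json

prompts = ["Once upon a time, there was a little", "The capital of France is"]
with open("input_complete.jsonl", "w", encoding="utf-8") as f:
    for i, p in enumerate(prompts, start=1):
        line = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/completions",
            "body": {"model": "default", "prompt": p, "max_tokens": 32},
        }
        f.write(json.dumps(line) + "\n")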
examples/usage/openai_parallel_sample.py  (+37 −0)

@@ -13,6 +13,17 @@ response = client.completions.create(
 print(response)

+# Text completion
+response = client.completions.create(
+    model="default",
+    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
+    n=1,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
 # Text completion
 response = client.completions.create(
     model="default",

@@ -24,6 +35,17 @@ response = client.completions.create(
 print(response)

+# Text completion
+response = client.completions.create(
+    model="default",
+    prompt=["The name of the famous soccer player is"],
+    n=1,
+    temperature=0.8,
+    max_tokens=128,
+)
+print(response)
+
 # Text completion
 response = client.completions.create(
     model="default",

@@ -60,6 +82,21 @@ response = client.completions.create(
 )
 print(response)

+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=64,
+    logprobs=True,
+    n=1,
+)
+print(response)
+
 # Chat completion
 response = client.chat.completions.create(
     model="default",
python/sglang/srt/managers/io_struct.py  (+20 −2)

@@ -79,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num = self.sampling_params.get("n", 1)
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num

             if parallel_sample_num != 1:
                 # parallel sampling +1 represents the original prefill stage
python/sglang/srt/managers/tokenizer_manager.py  (+13 −11)

@@ -84,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )

         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:

@@ -152,31 +153,33 @@ class TokenizerManager:
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
         if not is_cache_for_prefill:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+            not_use_index = not (index is not None)
+
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if index is not None and obj.input_ids:
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]

             self._validate_input_length(input_ids)

             sampling_params = self._get_sampling_params(
-                obj.sampling_params if index is None else obj.sampling_params[index]
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if index is None else obj.image_data[index]
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
-                obj.return_logprob if index is None else obj.return_logprob[index]
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
         else:
             if isinstance(obj.text, list):

@@ -224,7 +227,7 @@ class TokenizerManager:
     async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+        parallel_sample_num = obj.parallel_sample_num

         if parallel_sample_num != 1:
             # Send prefill requests to cache the common input

@@ -241,7 +244,6 @@ class TokenizerManager:
                     obj.input_ids = input_id_result
                 elif input_id_result is not None:
                     obj.input_ids = input_id_result[0]
-
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):

@@ -249,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
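To make the index arithmetic in _handle_batch_request concrete, here is a small standalone illustration (my own, not part of the commit) of where each (prompt, sample) pair lands when parallel sampling is enabled. It assumes, as the loop above suggests, that j == 0 is reserved for the cached prefill request of each prompt:

# index = i * parallel_sample_num + j, shifted by batch_size - 1 - i when n > 1,
# i.e. the comment's j + i * (parallel_sample_num - 1) + batch_size - 1.
batch_size, parallel_sample_num = 2, 3

for i in range(batch_size):
    for j in range(parallel_sample_num):
        if j == 0:
            continue  # assumed: covered by the prefill request for prompt i
        index = i * parallel_sample_num + j
        if parallel_sample_num != 1:
            index += batch_size - 1 - i
        print(f"prompt {i}, sample {j} -> request index {index}")
# prompt 0 -> indices 2, 3; prompt 1 -> indices 4, 5 (0 and 1 being the prefills)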
python/sglang/srt/openai_api/adapter.py  (+505 −139)

@@ -18,10 +18,14 @@ limitations under the License.

 import asyncio
 import json
+import os
 import time
 import uuid
 from http import HTTPStatus
 from typing import Dict, List, Optional

-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import ValidationError

 from sglang.srt.conversation import (
     Conversation,
@@ -32,6 +36,8 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
@@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
     UsageInfo,
 )
@@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
 chat_template_name = None


+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+## map file id to file path in SGlang backend
+file_id_storage: Dict[str, str] = {}
+
+
+# backend storage directory
+storage_dir = None
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -106,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg):
     chat_template_name = chat_template_arg


-async def v1_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-    prompt = request.prompt
-    if isinstance(prompt, str) or isinstance(prompt[0], str):
-        prompt_kwargs = {"text": prompt}
-    else:
-        prompt_kwargs = {"input_ids": prompt}
-
-    adapted_request = GenerateReqInput(
-        **prompt_kwargs,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-            "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
-        return_text_in_logprobs=True,
-        stream=request.stream,
-    )
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # add info to global file map
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
+
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            ## the batch_req here can be changed to be named within a batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+
+        # Write results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update batch response with output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+
+        # Update batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGlang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the batch job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+            }
+        )
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Batch operation is not supported for completions from files"
+            )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        if isinstance(prompts, str) or isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = GenerateReqInput(
+        **prompt_kwargs,
+        sampling_params=sampling_params_list,
+        return_logprob=all_requests[0].logprobs is not None
+        and all_requests[0].logprobs > 0,
+        top_logprobs_num=(
+            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
+        ),
+        return_text_in_logprobs=True,
+        stream=all_requests[0].stream,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case prompt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            ## to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)

     if adapted_request.stream:
@@ -223,109 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    if request.echo:
-        # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
-            prompts = request.prompt
-        else:
-            prompts = [request.prompt]
-
-    choices = []
-
-    for idx, ret_item in enumerate(ret):
-        text = ret_item["text"]
-
-        if request.echo:
-            text = prompts[idx] + text
-
-        if request.logprobs:
-            if request.echo:
-                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
-                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
-            else:
-                input_token_logprobs = None
-                input_top_logprobs = None
-
-            logprobs = to_openai_style_logprobs(
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
-                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
-            )
-        else:
-            logprobs = None
-
-        choice_data = CompletionResponseChoice(
-            index=idx,
-            text=text,
-            logprobs=logprobs,
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-
-    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
-    response = CompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
-            completion_tokens=completion_tokens,
-            total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
-        ),
-    )
-
-    return response
-
-
-async def v1_chat_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = ChatCompletionRequest(**request_json)
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    #  - prompt: The full prompt string.
-    #  - stop: Custom stop tokens.
-    #  - image_data: None or a list of image strings (URLs or base64 strings).
-    #    None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-        },
-        stream=request.stream,
-    )
+    response = v1_generate_response(request, ret)
+    return response
+
+
+def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    for request in all_requests:
+        # Prep the data needed for the underlying GenerateReqInput:
+        #  - prompt: The full prompt string.
+        #  - stop: Custom stop tokens.
+        #  - image_data: None or a list of image strings (URLs or base64 strings).
+        #    None skips any image processing in GenerateReqInput.
+        if not isinstance(request.messages, str):
+            # Apply chat template and its stop strings.
+            if chat_template_name is None:
+                prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=False, add_generation_prompt=True
+                )
+                stop = request.stop
+                image_data = None
+            else:
+                conv = generate_chat_conv(request, chat_template_name)
+                prompt = conv.get_prompt()
+                image_data = conv.image_data
+                stop = conv.stop_str or []
+                if request.stop:
+                    if isinstance(request.stop, str):
+                        stop.append(request.stop)
+                    else:
+                        stop.extend(request.stop)
+        else:
+            # Use the raw prompt and stop strings if the messages is already a string.
+            prompt = request.messages
+            stop = request.stop
+            image_data = None
+        texts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+            }
+        )
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
+        stream=request.stream,
+    )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
+
+    for idx, ret_item in enumerate(ret):
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
+
+        if to_file:
+            ## to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": None,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+        total_prompt_tokens = prompt_tokens
+        total_completion_tokens += completion_tokens
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

     if adapted_request.stream:
@@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
-
-    for idx, ret_item in enumerate(ret):
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
-
+    response = v1_chat_generate_response(request, ret)
     return response
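For orientation, this is roughly what one line of the result file written by process_batch (and retrievable through output_file_id) looks like for a chat batch; the values are illustrative, and "response" carries the per-request body that v1_chat_generate_response builds with to_file=True:

example_result_line = {
    "id": "batch_req_<uuid>",
    "custom_id": "request-1",
    "response": {
        "status_code": 200,
        "request_id": "<generation id>",
        "body": {
            "id": "<generation id>",
            "object": "chat.completion",
            "created": 1722297600,
            "model": "default",
            "choices": {
                "index": 0,
                "message": {"role": "assistant", "content": "..."},
                "logprobs": None,
                "finish_reason": "stop",
            },
            "usage": {"prompt_tokens": 20, "completion_tokens": 12, "total_tokens": 32},
            "system_fingerprint": None,
        },
    },
    "error": None,
}

Note that in the to_file path each result line carries a single choice dict rather than a list of choices.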
python/sglang/srt/openai_api/protocol.py  (+49 −0)

@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0


+class FileRequest(BaseModel):
+    # https://platform.openai.com/docs/api-reference/files/create
+    file: bytes  # The File object (not file name) to be uploaded
+    purpose: str = (
+        "batch"  # The intended purpose of the uploaded file, default is "batch"
+    )
+
+
+class FileResponse(BaseModel):
+    id: str
+    object: str = "file"
+    bytes: int
+    created_at: int
+    filename: str
+    purpose: str
+
+
+class BatchRequest(BaseModel):
+    input_file_id: (
+        str  # The ID of an uploaded file that contains requests for the new batch
+    )
+    endpoint: str  # The endpoint to be used for all requests in the batch
+    completion_window: str  # The time frame within which the batch should be processed
+    metadata: Optional[dict] = None  # Optional custom metadata for the batch
+
+
+class BatchResponse(BaseModel):
+    id: str
+    object: str = "batch"
+    endpoint: str
+    errors: Optional[dict] = None
+    input_file_id: str
+    completion_window: str
+    status: str = "validating"
+    output_file_id: Optional[str] = None
+    error_file_id: Optional[str] = None
+    created_at: int
+    in_progress_at: Optional[int] = None
+    expires_at: Optional[int] = None
+    finalizing_at: Optional[int] = None
+    completed_at: Optional[int] = None
+    failed_at: Optional[int] = None
+    expired_at: Optional[int] = None
+    cancelling_at: Optional[int] = None
+    cancelled_at: Optional[int] = None
+    request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
+    metadata: Optional[dict] = None
+
+
 class CompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
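A quick illustration (my own, not part of the commit) of how the new BatchResponse model is driven by process_batch in the adapter: it starts in the default "validating" state, and its status, timestamps, and request_counts are mutated in place as the job progresses:

import time

from sglang.srt.openai_api.protocol import BatchResponse

job = BatchResponse(
    id="batch_123",
    endpoint="/v1/chat/completions",
    input_file_id="backend_input_file-abc",
    completion_window="24h",
    created_at=int(time.time()),
)
print(job.status)           # "validating" (default)
job.status = "in_progress"  # set when process_batch starts
job.status = "completed"    # or "failed", with job.errors set to {"message": ...}
print(job.request_counts)   # {"total": 0, "completed": 0, "failed": 0} until updated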
python/sglang/srt/server.py  (+35 −1)

@@ -38,7 +38,7 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
+    v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_files_create,
+    v1_retrieve_batch,
+    v1_retrieve_file,
+    v1_retrieve_file_content,
 )
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/files")
+async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
+    return await v1_files_create(
+        file, purpose, tokenizer_manager.server_args.file_storage_pth
+    )
+
+
+@app.post("/v1/batches")
+async def openai_v1_batches(raw_request: Request):
+    return await v1_batches(tokenizer_manager, raw_request)
+
+
+@app.get("/v1/batches/{batch_id}")
+async def retrieve_batch(batch_id: str):
+    return await v1_retrieve_batch(batch_id)
+
+
+@app.get("/v1/files/{file_id}")
+async def retrieve_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve
+    return await v1_retrieve_file(file_id)
+
+
+@app.get("/v1/files/{file_id}/content")
+async def retrieve_file_content(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
+    return await v1_retrieve_file_content(file_id)
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
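A minimal sketch of exercising the new routes directly over HTTP with the requests library (assuming a server already running on http://127.0.0.1:30000, as in the examples above; not part of the commit):

import requests

base = "http://127.0.0.1:30000"

# Upload a batch input file (multipart form; purpose defaults to "batch").
with open("input.jsonl", "rb") as f:
    uploaded = requests.post(f"{base}/v1/files", files={"file": f}).json()

# Create the batch job against the chat completions endpoint.
batch = requests.post(
    f"{base}/v1/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# Poll the job, then stream down the result file once an output_file_id appears.
status = requests.get(f"{base}/v1/batches/{batch['id']}").json()
if status.get("output_file_id"):
    content = requests.get(f"{base}/v1/files/{status['output_file_id']}/content").content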
python/sglang/srt/server_args.py  (+7 −0)

@@ -60,6 +60,7 @@ class ServerArgs:
     # Other
     api_key: str = ""
+    file_storage_pth: str = "SGlang_storage"

     # Data parallelism
     dp_size: int = 1

@@ -290,6 +291,12 @@ class ServerArgs:
             default=ServerArgs.api_key,
             help="Set API key of the server.",
         )
+        parser.add_argument(
+            "--file-storage-pth",
+            type=str,
+            default=ServerArgs.file_storage_pth,
+            help="The path of the file storage in backend.",
+        )

         # Data parallelism
         parser.add_argument(
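A quick illustration of the new setting (my own sketch; the model path and directory are placeholders). Uploaded batch files and generated result files are written under this directory by the adapter:

from sglang.srt.server_args import ServerArgs

# Equivalent to passing --file-storage-pth on the command line.
args = ServerArgs(
    model_path="meta-llama/Llama-3-8B-Instruct",  # placeholder model
    file_storage_pth="/tmp/sglang_batch_storage",
)
print(args.file_storage_pth)  # defaults to "SGlang_storage" when not set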