Unverified commit 084fa54d, authored Jul 30, 2024 by yichuan~ and committed by GitHub on Jul 29, 2024.

Add support for OpenAI API : offline batch(file) processing (#699)

Co-authored-by: hnyls2002 <hnyls2002@gmail.com>

Parent: eba458bd
Showing 10 changed files with 839 additions and 154 deletions.
.pre-commit-config.yaml                           +1    -1
examples/usage/openai_batch_chat.py               +86   -0
examples/usage/openai_batch_complete.py           +86   -0
examples/usage/openai_parallel_sample.py          +37   -0
python/sglang/srt/managers/io_struct.py           +20   -2
python/sglang/srt/managers/tokenizer_manager.py   +13   -11
python/sglang/srt/openai_api/adapter.py           +505  -139
python/sglang/srt/openai_api/protocol.py          +49   -0
python/sglang/srt/server.py                       +35   -1
python/sglang/srt/server_args.py                  +7    -0
.pre-commit-config.yaml

@@ -4,6 +4,6 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: stable
+    rev: 24.4.2
     hooks:
       - id: black

examples/usage/openai_batch_chat.py  (new file, 0 → 100644)

import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_chat_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)

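The script above reads its requests from input.jsonl, but the file's layout is only implied. Based on the batch handler added in python/sglang/srt/openai_api/adapter.py below, which reads each line's custom_id and body and routes by the batch endpoint, the input follows the OpenAI batch-input JSONL shape. A minimal sketch for producing a two-request chat input file; the prompts and the "default" model name are placeholders, and the method/url fields are kept only for OpenAI-format compatibility (this adapter does not read them):

import json

# One JSON object per line; the adapter uses "custom_id" and "body" and
# dispatches on the batch's endpoint ("/v1/chat/completions" here).
requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32,
        },
    }
    for i, prompt in enumerate(["What is 2+2?", "Name one prime number."])
]

with open("input.jsonl", "w", encoding="utf-8") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")
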
examples/usage/openai_batch_complete.py  (new file, 0 → 100644)

import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):

        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":

            # print result of batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_complete_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Process the batch job
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)

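The completions variant expects input_complete.jsonl in the same line-per-request layout, except that the batch endpoint is /v1/completions and each body carries CompletionRequest fields (prompt, max_tokens, ...). A minimal sketch with a placeholder prompt and model name:

import json

line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/completions",
    # "body" is parsed as a CompletionRequest by the batch processor.
    "body": {"model": "default", "prompt": "Once upon a time,", "max_tokens": 32},
}

with open("input_complete.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(line) + "\n")
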
examples/usage/openai_parallel_sample.py

@@ -13,6 +13,17 @@ response = client.completions.create(
 print(response)
 
+# Text completion
+response = client.completions.create(
+    model="default",
+    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
+    n=1,
+    temperature=0.8,
+    max_tokens=32,
+)
+print(response)
+
 # Text completion
 response = client.completions.create(
     model="default",
...
@@ -24,6 +35,17 @@ response = client.completions.create(
 print(response)
 
+# Text completion
+response = client.completions.create(
+    model="default",
+    prompt=["The name of the famous soccer player is"],
+    n=1,
+    temperature=0.8,
+    max_tokens=128,
+)
+print(response)
+
 # Text completion
 response = client.completions.create(
     model="default",
...
@@ -60,6 +82,21 @@ response = client.completions.create(
 )
 print(response)
 
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=64,
+    logprobs=True,
+    n=1,
+)
+print(response)
+
 # Chat completion
 response = client.chat.completions.create(
     model="default",
...

python/sglang/srt/managers/io_struct.py

@@ -79,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num = self.sampling_params.get("n", 1)
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
...

python/sglang/srt/managers/tokenizer_manager.py

@@ -84,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
 
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:
...
@@ -152,31 +153,33 @@ class TokenizerManager:
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
         if not is_cache_for_prefill:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+            not_use_index = not (index is not None)
+
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if index is not None and obj.input_ids:
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]
 
             self._validate_input_length(input_ids)
 
             sampling_params = self._get_sampling_params(
-                obj.sampling_params if index is None else obj.sampling_params[index]
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if index is None else obj.image_data[index]
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
-                obj.return_logprob if index is None else obj.return_logprob[index]
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
         else:
             if isinstance(obj.text, list):
...
@@ -224,7 +227,7 @@ class TokenizerManager:
     async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+        parallel_sample_num = obj.parallel_sample_num
 
         if parallel_sample_num != 1:
             # Send prefill requests to cache the common input
...
@@ -241,7 +244,6 @@ class TokenizerManager:
                     obj.input_ids = input_id_result
                 elif input_id_result is not None:
                     obj.input_ids = input_id_result[0]
 
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
...
@@ -249,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                    index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
...

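The corrected comment in the last hunk states the combined request index in closed form. A quick, self-contained check (illustrative values only, not part of the commit) that the two-step computation in _handle_batch_request matches that formula:

# index = i * n + j, then shifted by batch_size - 1 - i when n != 1,
# equals j + i * (n - 1) + batch_size - 1, as the comment says.
batch_size, parallel_sample_num = 2, 3  # illustrative values

for i in range(batch_size):
    for j in range(parallel_sample_num):
        index = i * parallel_sample_num + j
        if parallel_sample_num != 1:
            index += batch_size - 1 - i
        assert index == j + i * (parallel_sample_num - 1) + batch_size - 1
        print(f"request {i}, sample {j} -> combined index {index}")
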
python/sglang/srt/openai_api/adapter.py

@@ -18,10 +18,14 @@ limitations under the License.
 import asyncio
 import json
 import os
+import time
+import uuid
 from http import HTTPStatus
+from typing import Dict, List, Optional
 
-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import ValidationError
 
 from sglang.srt.conversation import (
     Conversation,
...
@@ -32,6 +36,8 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
...
@@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
     UsageInfo,
 )
...
@@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
 chat_template_name = None
 
+
+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+## map file id to file path in SGlang backend
+file_id_storage: Dict[str, str] = {}
+
+# backend storage directory
+storage_dir = None
+
 
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
...
@@ -106,18 +132,216 @@ def load_chat_template_for_openai_api(chat_template_arg):
         chat_template_name = chat_template_arg
 
-async def v1_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-    prompt = request.prompt
-
-    if isinstance(prompt, str) or isinstance(prompt[0], str):
-        prompt_kwargs = {"text": prompt}
-    else:
-        prompt_kwargs = {"input_ids": prompt}
-
-    adapted_request = GenerateReqInput(
-        **prompt_kwargs,
-        sampling_params={
+
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # add info to global file map
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
+
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            ## the batch_req here can be changed to be named within a batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+
+        # Write results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update batch response with output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+
+        # Update batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGlang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the batch job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
...
@@ -127,12 +351,145 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 "regex": request.regex,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
-        return_text_in_logprobs=True,
-        stream=request.stream,
-    )
+            }
+        )
+
+    if len(all_requests) > 1 and request.n > 1:
+        raise ValueError(
+            "Batch operation is not supported for completions from files"
+        )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        if isinstance(prompts, str) or isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = GenerateReqInput(
+        **prompt_kwargs,
+        sampling_params=sampling_params_list,
+        return_logprob=all_requests[0].logprobs is not None
+        and all_requests[0].logprobs > 0,
+        top_logprobs_num=(
+            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
+        ),
+        return_text_in_logprobs=True,
+        stream=all_requests[0].stream,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case propmt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            ## to make the choise data json serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)
 
     if adapted_request.stream:
...
@@ -223,65 +580,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
-    if request.echo:
-        # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
-            prompts = request.prompt
-        else:
-            prompts = [request.prompt]
-
-    choices = []
-
-    for idx, ret_item in enumerate(ret):
-        text = ret_item["text"]
-
-        if request.echo:
-            text = prompts[idx] + text
-
-        if request.logprobs:
-            if request.echo:
-                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
-                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
-            else:
-                input_token_logprobs = None
-                input_top_logprobs = None
-
-            logprobs = to_openai_style_logprobs(
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
-                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
-            )
-        else:
-            logprobs = None
-
-        choice_data = CompletionResponseChoice(
-            index=idx,
-            text=text,
-            logprobs=logprobs,
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-
-    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
-    response = CompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
-            completion_tokens=completion_tokens,
-            total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
-        ),
-    )
-
+    response = v1_generate_response(request, ret)
     return response
 
-async def v1_chat_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = ChatCompletionRequest(**request_json)
+def v1_chat_generate_request(all_requests, tokenizer_manager):
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
         #  - stop: Custom stop tokens.
...
@@ -310,11 +619,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         prompt = request.messages
         stop = request.stop
         image_data = None
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
+        texts.append(prompt)
+        sampling_params_list.append(
+            {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
                 "stop": stop,
...
@@ -323,9 +630,94 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 "frequency_penalty": request.frequency_penalty,
                 "regex": request.regex,
                 "n": request.n,
-        },
+            }
+        )
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
         stream=request.stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
+
+    for idx, ret_item in enumerate(ret):
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
+
+        if to_file:
+            ## to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": None,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+        total_prompt_tokens = prompt_tokens
+        total_completion_tokens += completion_tokens
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
     if adapted_request.stream:
...
@@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
-
-    for idx, ret_item in enumerate(ret):
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
-
+    response = v1_chat_generate_response(request, ret)
     return response
...

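Each line that process_batch writes to the result file pairs a request's custom_id with either a response (whose body mirrors a chat.completion or text_completion object, holding a single choice dict under "choices" in the to_file path) or an error. A small sketch of pulling the assistant text back out of the chat results file produced by the first example; the file name is the one used there:

import json

with open("batch_job_chat_results.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        if record.get("error"):
            print(record["custom_id"], "failed:", record["error"]["message"])
            continue
        body = record["response"]["body"]
        # In the to_file path above, "choices" is a single choice dict.
        print(record["custom_id"], "->", body["choices"]["message"]["content"])
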
python/sglang/srt/openai_api/protocol.py

@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
+
+class FileRequest(BaseModel):
+    # https://platform.openai.com/docs/api-reference/files/create
+    file: bytes  # The File object (not file name) to be uploaded
+    purpose: str = (
+        "batch"  # The intended purpose of the uploaded file, default is "batch"
+    )
+
+
+class FileResponse(BaseModel):
+    id: str
+    object: str = "file"
+    bytes: int
+    created_at: int
+    filename: str
+    purpose: str
+
+
+class BatchRequest(BaseModel):
+    input_file_id: (
+        str  # The ID of an uploaded file that contains requests for the new batch
+    )
+    endpoint: str  # The endpoint to be used for all requests in the batch
+    completion_window: str  # The time frame within which the batch should be processed
+    metadata: Optional[dict] = None  # Optional custom metadata for the batch
+
+
+class BatchResponse(BaseModel):
+    id: str
+    object: str = "batch"
+    endpoint: str
+    errors: Optional[dict] = None
+    input_file_id: str
+    completion_window: str
+    status: str = "validating"
+    output_file_id: Optional[str] = None
+    error_file_id: Optional[str] = None
+    created_at: int
+    in_progress_at: Optional[int] = None
+    expires_at: Optional[int] = None
+    finalizing_at: Optional[int] = None
+    completed_at: Optional[int] = None
+    failed_at: Optional[int] = None
+    expired_at: Optional[int] = None
+    cancelling_at: Optional[int] = None
+    cancelled_at: Optional[int] = None
+    request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
+    metadata: Optional[dict] = None
+
+
 class CompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
...

python/sglang/srt/server.py

@@ -38,7 +38,7 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
...
@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
+    v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_files_create,
+    v1_retrieve_batch,
+    v1_retrieve_file,
+    v1_retrieve_file_content,
 )
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
...
@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
+
+@app.post("/v1/files")
+async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
+    return await v1_files_create(
+        file, purpose, tokenizer_manager.server_args.file_storage_pth
+    )
+
+
+@app.post("/v1/batches")
+async def openai_v1_batches(raw_request: Request):
+    return await v1_batches(tokenizer_manager, raw_request)
+
+
+@app.get("/v1/batches/{batch_id}")
+async def retrieve_batch(batch_id: str):
+    return await v1_retrieve_batch(batch_id)
+
+
+@app.get("/v1/files/{file_id}")
+async def retrieve_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve
+    return await v1_retrieve_file(file_id)
+
+
+@app.get("/v1/files/{file_id}/content")
+async def retrieve_file_content(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
+    return await v1_retrieve_file_content(file_id)
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
...

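The new routes can also be exercised without the OpenAI client. A sketch using requests against a locally running server, assuming it was launched with the --file-storage-pth flag added below (default "SGlang_storage") pointing at a writable directory; the base URL and file name are placeholders, while the multipart field names and JSON body fields come from openai_v1_files and BatchRequest above:

import time

import requests

base = "http://127.0.0.1:30000"  # placeholder: wherever the SGLang server runs

# 1. Upload the batch input file (multipart fields mirror openai_v1_files).
with open("input.jsonl", "rb") as f:
    uploaded = requests.post(
        f"{base}/v1/files", files={"file": f}, data={"purpose": "batch"}
    ).json()

# 2. Create the batch job (fields mirror BatchRequest in protocol.py).
batch = requests.post(
    f"{base}/v1/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# 3. Poll the batch until it reaches a terminal status.
while batch["status"] not in ("completed", "failed", "cancelled"):
    time.sleep(3)
    batch = requests.get(f"{base}/v1/batches/{batch['id']}").json()

# 4. Download the result file content.
if batch["status"] == "completed":
    content = requests.get(
        f"{base}/v1/files/{batch['output_file_id']}/content"
    ).content
    print(content.decode("utf-8"))
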
python/sglang/srt/server_args.py

@@ -60,6 +60,7 @@ class ServerArgs:
     # Other
     api_key: str = ""
+    file_storage_pth: str = "SGlang_storage"
 
     # Data parallelism
     dp_size: int = 1
...
@@ -290,6 +291,12 @@ class ServerArgs:
             default=ServerArgs.api_key,
             help="Set API key of the server.",
         )
+        parser.add_argument(
+            "--file-storage-pth",
+            type=str,
+            default=ServerArgs.file_storage_pth,
+            help="The path of the file storage in backend.",
+        )
 
         # Data parallelism
         parser.add_argument(
...