sglang / Commits / 0d4f3a9f

Unverified commit 0d4f3a9f, authored Aug 04, 2024 by Ying Sheng, committed by GitHub on Aug 04, 2024

Make API Key OpenAI-compatible (#917)
Parent: afd411d0

Showing 7 changed files with 115 additions and 125 deletions.
python/sglang/lang/backend/runtime_endpoint.py   +0   -11
python/sglang/srt/server.py                      +63  -75
python/sglang/srt/server_args.py                 +2   -2
python/sglang/srt/utils.py                       +25  -20
python/sglang/test/test_utils.py                 +13  -2
python/sglang/utils.py                           +3   -9
test/srt/test_openai_server.py                   +9   -6
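In short, the commit moves the server's API-key check from a custom X-API-Key header to the OpenAI-style Authorization: Bearer header, so the stock OpenAI clients can authenticate without extra plumbing. A minimal sketch of the client-side difference (the key is illustrative and must match the server's --api-key):

    # Illustrative key; must match the value passed to --api-key.
    api_key = "sk-123456"

    # Before this commit: sglang's server expected a custom header.
    old_headers = {"X-API-Key": api_key}

    # After this commit: it expects the OpenAI-style bearer token,
    # the same header the official OpenAI clients already send.
    new_headers = {"Authorization": f"Bearer {api_key}"}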
python/sglang/lang/backend/runtime_endpoint.py

@@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True
         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify

         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
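With auth_token removed, RuntimeEndpoint callers pass only api_key, which http_request turns into the bearer header on every request. A hedged usage sketch (server address and key are illustrative, not part of this commit):

    import sglang as sgl
    from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

    # Assumes a local SRT server launched with --api-key sk-123456 (illustrative values).
    backend = RuntimeEndpoint("http://localhost:30000", api_key="sk-123456")
    sgl.set_default_backend(backend)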
python/sglang/srt/server.py

@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
     """Launch an HTTP server."""
     global tokenizer_manager
@@ -258,16 +210,8 @@ def launch_server(
         format="%(message)s",
     )

-    if not server_args.disable_flashinfer:
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)

     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +228,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

-    # Handle multi-node tensor parallelism
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +293,9 @@ def launch_server(
             sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

-    if server_args.api_key and server_args.api_key != "":
-        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)

     # Send a warmup request
     t = threading.Thread(
@@ -372,15 +317,58 @@ def launch_server(
     t.join()


+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[API_KEY_HEADER_NAME] = server_args.api_key
+        headers["Authorization"] = f"Bearer {server_args.api_key}"

     # Wait until the server is launched
     for _ in range(120):
-        time.sleep(0.5)
+        time.sleep(1)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
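Once add_api_key_middleware is attached, every route except /health (and CORS preflights) requires the bearer token, including the relocated /v1/models endpoint. A hedged sketch of probing a protected server (URL and key are illustrative):

    import requests

    base_url = "http://localhost:30000"   # illustrative local server
    api_key = "sk-123456"                 # illustrative, must match --api-key

    # Without the bearer token the middleware rejects the request.
    print(requests.get(f"{base_url}/v1/models").status_code)   # expected 401

    # With the OpenAI-style header the request goes through.
    headers = {"Authorization": f"Bearer {api_key}"}
    print(requests.get(f"{base_url}/v1/models", headers=headers).json())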
python/sglang/srt/server_args.py

@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False

     # Other
-    api_key: str = ""
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"

     # Data parallelism
@@ -307,7 +307,7 @@ class ServerArgs:
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server.",
+            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
             "--file-storage-pth",
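Because api_key now defaults to None, authentication stays off unless the flag is passed explicitly. A hedged launch sketch, mirroring what the test utilities do via subprocess (model path, port, and key are placeholders, not from this commit):

    import subprocess

    # Illustrative launch command; adjust the model path, port, and key to your setup.
    subprocess.run([
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
        "--port", "30000",
        "--api-key", "sk-123456",
    ])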
python/sglang/srt/utils.py

@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
                 raise RuntimeError("Could not create or locate cache dir")


-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
         dist.destroy_process_group()


+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
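The new helper replaces the BaseHTTPMiddleware subclass with a plain @app.middleware("http") hook that checks the bearer token and exempts /health and OPTIONS requests. A self-contained sketch of the same pattern against a toy FastAPI app (not sglang's actual server; routes and key are illustrative):

    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from fastapi.testclient import TestClient

    app = FastAPI()

    @app.get("/health")
    def health():
        return {"status": "ok"}

    @app.get("/v1/models")
    def models():
        return {"data": ["demo-model"]}

    def add_api_key_middleware(app, api_key):
        # Same shape as the sglang helper: skip CORS preflights and /health,
        # otherwise require an OpenAI-style bearer token.
        @app.middleware("http")
        async def authentication(request, call_next):
            if request.method == "OPTIONS":
                return await call_next(request)
            if request.url.path.startswith("/health"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + api_key:
                return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return await call_next(request)

    add_api_key_middleware(app, "sk-123456")

    client = TestClient(app)
    print(client.get("/health").status_code)      # 200, exempt from the check
    print(client.get("/v1/models").status_code)   # 401, missing bearer token
    print(client.get("/v1/models", headers={"Authorization": "Bearer sk-123456"}).status_code)  # 200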
python/sglang/test/test_utils.py

@@ -391,7 +391,11 @@ def get_call_select(args: argparse.Namespace):
 def popen_launch_server(
-    model: str, base_url: str, timeout: float, other_args: tuple = ()
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: tuple = (),
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -408,12 +412,19 @@ def popen_launch_server(
         port,
         *other_args,
     ]
+    if api_key:
+        command += ["--api-key", api_key]
+
     process = subprocess.Popen(command, stdout=None, stderr=None)

     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-            response = requests.get(f"{base_url}/v1/models")
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {api_key}",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
             if response.status_code == 200:
                 return process
         except requests.RequestException:
python/sglang/utils.py

@@ -76,19 +76,13 @@ class HttpResponse:
         return self.resp.status


-def http_request(
-    url, json=None, stream=False, auth_token=None, api_key=None, verify=None
-):
+def http_request(url, json=None, stream=False, api_key=None, verify=None):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}

-    # add the Authorization header if an auth token is provided
-    if auth_token is not None:
-        headers["Authorization"] = f"Bearer {auth_token}"
-
-    # add the API Key header if an API key is provided
+    # add the Authorization header if an api key is provided
     if api_key is not None:
-        headers["X-API-Key"] = api_key
+        headers["Authorization"] = f"Bearer {api_key}"

     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
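After this change, http_request maps api_key straight onto the standard Authorization header and the separate auth_token path disappears. A hedged call sketch (URL and key are illustrative; assumes a local server launched with --api-key):

    from sglang.utils import http_request

    # Illustrative values; the key must match the server's --api-key.
    res = http_request(
        "http://localhost:30000/get_model_info",
        api_key="sk-123456",   # now sent as "Authorization: Bearer sk-123456"
    )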
test/srt/test_openai_server.py

@@ -13,7 +13,10 @@ class TestOpenAIServer(unittest.TestCase):
     def setUpClass(cls):
         cls.model = MODEL_NAME_FOR_TEST
         cls.base_url = f"http://localhost:30000"
-        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+        )
         cls.base_url += "/v1"

     @classmethod
@@ -21,7 +24,7 @@ class TestOpenAIServer(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_completion(self, echo, logprobs, use_list_input):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"

         if use_list_input:
@@ -63,7 +66,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_completion_stream(self, echo, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         generator = client.completions.create(
             model=self.model,
@@ -102,7 +105,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -135,7 +138,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion_stream(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -186,7 +189,7 @@ class TestOpenAIServer(unittest.TestCase):
         self.run_chat_completion_stream(logprobs)

     def test_regex(self):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         regex = (
             r"""\{\n"""
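As the updated tests show, the unmodified OpenAI client now authenticates against a key-protected sglang server simply by passing the same key it would send to the OpenAI API. A hedged standalone sketch with illustrative values (the model name and key are placeholders):

    import openai

    # Key and URL are illustrative; the key must match the server's --api-key.
    client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

    response = client.chat.completions.create(
        model="default",   # illustrative; GET /v1/models lists the served model name
        messages=[{"role": "user", "content": "The capital of France is"}],
        max_tokens=8,
    )
    print(response.choices[0].message.content)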