Project: change/sglang
Commit: dbe17293 (unverified)

Merged three native APIs into one: get_server_info (#2152)

Authored by Henry Hyeonmok Ko on Nov 24, 2024; committed by GitHub on Nov 24, 2024.
Parent: 84a1698d

Showing 10 changed files with 74 additions and 119 deletions (+74, -119).
Changed files:

benchmark/json_schema/bench_sglang.py           +1  -1
docs/backend/native_api.ipynb                   +20 -62
python/sglang/__init__.py                       +2  -2
python/sglang/api.py                            +2  -2
python/sglang/lang/backend/base_backend.py      +1  -1
python/sglang/lang/backend/runtime_endpoint.py  +2  -2
python/sglang/srt/server.py                     +27 -40
rust/src/server.rs                              +4  -4
test/srt/test_data_parallelism.py               +3  -2
test/srt/test_srt_endpoint.py                   +12 -3
benchmark/json_schema/bench_sglang.py

@@ -113,7 +113,7 @@ def main(args):
     # Compute accuracy
     tokenizer = get_tokenizer(
-        global_config.default_backend.get_server_args()["tokenizer_path"]
+        global_config.default_backend.get_server_info()["tokenizer_path"]
     )
     output_jsons = [state["json_output"] for state in states]
     num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
docs/backend/native_api.ipynb

@@ -9,13 +9,11 @@
 "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
 "\n",
 "- `/generate` (text generation model)\n",
-"- `/get_server_args`\n",
 "- `/get_model_info`\n",
+"- `/get_server_info`\n",
 "- `/health`\n",
 "- `/health_generate`\n",
 "- `/flush_cache`\n",
-"- `/get_memory_pool_size`\n",
-"- `/get_max_total_num_tokens`\n",
 "- `/update_weights`\n",
 "- `/encode`(embedding model)\n",
 "- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
 "print_highlight(response.json())"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Server Args\n",
-"Get the arguments of a server."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/get_server_args\"\n",
-"\n",
-"response = requests.get(url)\n",
-"print_highlight(response.json())"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -127,9 +105,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Health Check\n",
-"- `/health`: Check the health of the server.\n",
-"- `/health_generate`: Check the health of the server by generating one token."
+"## Get Server Info\n",
+"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
+"- Note: `get_server_info` merges the following deprecated endpoints:\n",
+"  - `get_server_args`\n",
+"  - `get_memory_pool_size`\n",
+"  - `get_max_total_num_tokens`"
 ]
 },
 {
@@ -138,19 +119,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"url = \"http://localhost:30010/health_generate\"\n",
-"\n",
-"response = requests.get(url)\n",
-"print_highlight(response.text)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/health\"\n",
+"# get_server_info\n",
+"\n",
+"url = \"http://localhost:30010/get_server_info\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -160,9 +131,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Flush Cache\n",
-"\n",
-"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
+"## Health Check\n",
+"- `/health`: Check the health of the server.\n",
+"- `/health_generate`: Check the health of the server by generating one token."
 ]
 },
 {
@@ -171,32 +142,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# flush cache\n",
-"\n",
-"url = \"http://localhost:30010/flush_cache\"\n",
+"url = \"http://localhost:30010/health_generate\"\n",
 "\n",
-"response = requests.post(url)\n",
+"response = requests.get(url)\n",
 "print_highlight(response.text)"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Memory Pool Size\n",
-"\n",
-"Get the memory pool size in number of tokens.\n"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_memory_pool_size\n",
-"\n",
-"url = \"http://localhost:30010/get_memory_pool_size\"\n",
+"url = \"http://localhost:30010/health\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -206,9 +164,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Get Maximum Total Number of Tokens\n",
-"\n",
-"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+"## Flush Cache\n",
+"\n",
+"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
 ]
 },
 {
@@ -217,11 +175,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_max_total_num_tokens\n",
+"# flush cache\n",
 "\n",
-"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+"url = \"http://localhost:30010/flush_cache\"\n",
 "\n",
-"response = requests.get(url)\n",
+"response = requests.post(url)\n",
 "print_highlight(response.text)"
 ]
 },
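The notebook's deprecation note maps the three retired endpoints onto fields of one merged response. A hedged migration sketch (field names taken from _get_server_info in python/sglang/srt/server.py below; the port matches the notebook):

import requests

base = "http://localhost:30010"

# Before this commit: three round trips.
# server_args = requests.get(base + "/get_server_args").json()
# pool_size   = requests.get(base + "/get_memory_pool_size").json()
# max_tokens  = requests.get(base + "/get_max_total_num_tokens").json()

# After this commit: one round trip; the old payloads are fields of one dict.
info = requests.get(base + "/get_server_info").json()
pool_size = info["memory_pool_size"]
max_tokens = info["max_total_num_tokens"]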
python/sglang/__init__.py

@@ -11,7 +11,7 @@ from sglang.api import (
     gen,
     gen_int,
     gen_string,
-    get_server_args,
+    get_server_info,
     image,
     select,
     set_default_backend,
@@ -41,7 +41,7 @@ __all__ = [
     "gen",
     "gen_int",
     "gen_string",
-    "get_server_args",
+    "get_server_info",
     "image",
     "select",
     "set_default_backend",
python/sglang/api.py

@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     return backend.flush_cache()


-def get_server_args(backend: Optional[BaseBackend] = None):
+def get_server_info(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     # If backend is Runtime
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
-    return backend.get_server_args()
+    return backend.get_server_info()


 def gen(
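Since get_server_info resolves its backend exactly like flush_cache above it, the frontend call needs no arguments once a default backend is set. A minimal sketch, assuming a server is already running at the notebook's address:

import sglang as sgl
from sglang import RuntimeEndpoint

# Hypothetical local deployment; adjust the URL to your server.
sgl.set_default_backend(RuntimeEndpoint("http://localhost:30010"))

# With no explicit backend, the default backend is used; if it is a
# Runtime, the call is delegated to its endpoint (see the hunk above).
info = sgl.get_server_info()
print(info["tokenizer_path"])  # same field the benchmark script reads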
python/sglang/lang/backend/base_backend.py

@@ -78,5 +78,5 @@ class BaseBackend:
     def flush_cache(self):
         pass

-    def get_server_args(self):
+    def get_server_info(self):
         pass
python/sglang/lang/backend/runtime_endpoint.py

@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)

-    def get_server_args(self):
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/get_server_args",
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
python/sglang/srt/server.py

@@ -146,10 +146,15 @@ async def get_model_info():
     return result


-@app.get("/get_server_args")
-async def get_server_args():
-    """Get the server arguments."""
-    return dataclasses.asdict(tokenizer_manager.server_args)
+@app.get("/get_server_info")
+async def get_server_info():
+    try:
+        return await _get_server_info()
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )


 @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
     )


-@app.get("/get_max_total_num_tokens")
-async def get_max_total_num_tokens():
-    try:
-        return {"max_total_num_tokens": _get_max_total_num_tokens()}
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
-@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
-async def get_memory_pool_size():
-    """Get the memory pool size in number of tokens"""
-    try:
-        ret = await tokenizer_manager.get_memory_pool_size()
-        return ret
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.post("/update_weights")
 @time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
     t.join()


-def _get_max_total_num_tokens():
-    return _max_total_num_tokens
+async def _get_server_info():
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
+        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+    }


 def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

-    def get_max_total_num_tokens(self):
-        response = requests.get(f"{self.url}/get_max_total_num_tokens")
-        if response.status_code == 200:
-            return response.json()["max_total_num_tokens"]
-        else:
-            raise RuntimeError(
-                f"Failed to get max tokens. {response.json()['error']['message']}"
-            )
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )

     def __del__(self):
         self.shutdown()
@@ -946,5 +933,5 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

-    def get_max_total_num_tokens(self):
-        return _get_max_total_num_tokens()
+    async def get_server_info(self):
+        return await _get_server_info()
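Note the signature change on the Runtime and Engine wrappers: get_server_info is now a coroutine (the Runtime version uses aiohttp), so call sites that previously used the synchronous get_max_total_num_tokens must await it. A hedged sketch, assuming runtime is an already-launched sglang Runtime:

import asyncio

async def show_server_info(runtime):
    # Now a coroutine: must be awaited, unlike the old synchronous
    # get_max_total_num_tokens it replaces.
    info = await runtime.get_server_info()
    print(info["max_total_num_tokens"], info["memory_pool_size"])

# asyncio.run(show_server_info(runtime))  # runtime: e.g. sglang.Runtime(...)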
rust/src/server.rs

@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
     forward_request(&data.client, worker_url, "/health_generate".to_string()).await
 }

-#[get("/get_server_args")]
-async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+#[get("/get_server_info")]
+async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };
-    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
 }

 #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
             .service(get_model_info)
             .service(health)
             .service(health_generate)
-            .service(get_server_args)
+            .service(get_server_info)
     })
     .bind((host, port))?
     .run()
test/srt/test_data_parallelism.py

@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200

     def test_get_memory_pool_size(self):
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200

         time.sleep(5)

-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
test/srt/test_srt_endpoint.py

@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

-    def test_get_memory_pool_size(self):
-        response = requests.post(self.base_url + "/get_memory_pool_size")
-        self.assertIsInstance(response.json(), int)
+    def test_get_server_info(self):
+        response = requests.get(self.base_url + "/get_server_info")
+        response_json = response.json()
+
+        max_total_num_tokens = response_json["max_total_num_tokens"]
+        self.assertIsInstance(max_total_num_tokens, int)
+
+        memory_pool_size = response_json["memory_pool_size"]
+        self.assertIsInstance(memory_pool_size, int)
+
+        attention_backend = response_json["attention_backend"]
+        self.assertIsInstance(attention_backend, str)


 if __name__ == "__main__":