sglang · Commit dbe17293 (unverified)
Authored Nov 24, 2024 by Henry Hyeonmok Ko; committed by GitHub on Nov 24, 2024

Merged three native APIs into one: get_server_info (#2152)

Parent: 84a1698d
Showing 10 changed files with 74 additions and 119 deletions (+74 −119):

- benchmark/json_schema/bench_sglang.py (+1 −1)
- docs/backend/native_api.ipynb (+20 −62)
- python/sglang/__init__.py (+2 −2)
- python/sglang/api.py (+2 −2)
- python/sglang/lang/backend/base_backend.py (+1 −1)
- python/sglang/lang/backend/runtime_endpoint.py (+2 −2)
- python/sglang/srt/server.py (+27 −40)
- rust/src/server.rs (+4 −4)
- test/srt/test_data_parallelism.py (+3 −2)
- test/srt/test_srt_endpoint.py (+12 −3)
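Taken together, three separate GET endpoints (`/get_server_args`, `/get_memory_pool_size`, `/get_max_total_num_tokens`) are replaced by the single `/get_server_info` route. A minimal client-side sketch of the merged call, assuming a server listening on localhost:30010 as in the updated notebook:

import requests

# One call now returns what previously took three requests:
# the server args, the memory pool size, and max_total_num_tokens.
info = requests.get("http://localhost:30010/get_server_info").json()
print(info["max_total_num_tokens"])  # formerly GET /get_max_total_num_tokens
print(info["memory_pool_size"])      # formerly GET /get_memory_pool_size
print(info["tokenizer_path"])        # a server arg, formerly GET /get_server_args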
benchmark/json_schema/bench_sglang.py

@@ -113,7 +113,7 @@ def main(args):
     # Compute accuracy
     tokenizer = get_tokenizer(
-        global_config.default_backend.get_server_args()["tokenizer_path"]
+        global_config.default_backend.get_server_info()["tokenizer_path"]
     )
     output_jsons = [state["json_output"] for state in states]
     num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
docs/backend/native_api.ipynb

@@ -9,13 +9,11 @@
    "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
    "\n",
    "- `/generate` (text generation model)\n",
-   "- `/get_server_args`\n",
    "- `/get_model_info`\n",
+   "- `/get_server_info`\n",
    "- `/health`\n",
    "- `/health_generate`\n",
    "- `/flush_cache`\n",
-   "- `/get_memory_pool_size`\n",
-   "- `/get_max_total_num_tokens`\n",
    "- `/update_weights`\n",
    "- `/encode`(embedding model)\n",
    "- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
    "print_highlight(response.json())"
   ]
  },
- {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "## Get Server Args\n",
-   "Get the arguments of a server."
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "url = \"http://localhost:30010/get_server_args\"\n",
-   "\n",
-   "response = requests.get(url)\n",
-   "print_highlight(response.json())"
-  ]
- },
  {
   "cell_type": "markdown",
   "metadata": {},
@@ -127,9 +105,12 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## Health Check\n",
-   "- `/health`: Check the health of the server.\n",
-   "- `/health_generate`: Check the health of the server by generating one token."
+   "## Get Server Info\n",
+   "Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
+   "- Note: `get_server_info` merges the following deprecated endpoints:\n",
+   "  - `get_server_args`\n",
+   "  - `get_memory_pool_size`\n",
+   "  - `get_max_total_num_tokens`"
  ]
 },
 {
@@ -138,19 +119,9 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "url = \"http://localhost:30010/health_generate\"\n",
-   "\n",
-   "response = requests.get(url)\n",
-   "print_highlight(response.text)"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "url = \"http://localhost:30010/health\"\n",
+   "# get_server_info\n",
+   "\n",
+   "url = \"http://localhost:30010/get_server_info\"\n",
    "\n",
    "response = requests.get(url)\n",
    "print_highlight(response.text)"
@@ -160,9 +131,9 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## Flush Cache\n",
-   "\n",
-   "Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
+   "## Health Check\n",
+   "- `/health`: Check the health of the server.\n",
+   "- `/health_generate`: Check the health of the server by generating one token."
  ]
 },
 {
@@ -171,32 +142,19 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# flush cache\n",
-   "\n",
-   "url = \"http://localhost:30010/flush_cache\"\n",
+   "url = \"http://localhost:30010/health_generate\"\n",
    "\n",
-   "response = requests.post(url)\n",
+   "response = requests.get(url)\n",
    "print_highlight(response.text)"
   ]
  },
  {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "## Get Memory Pool Size\n",
-   "\n",
-   "Get the memory pool size in number of tokens.\n"
-  ]
- },
- {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
-   "# get_memory_pool_size\n",
-   "\n",
-   "url = \"http://localhost:30010/get_memory_pool_size\"\n",
+   "url = \"http://localhost:30010/health\"\n",
    "\n",
    "response = requests.get(url)\n",
    "print_highlight(response.text)"
@@ -206,9 +164,9 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## Get Maximum Total Number of Tokens\n",
-   "\n",
-   "Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+   "## Flush Cache\n",
+   "\n",
+   "Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
  ]
 },
 {
@@ -217,11 +175,11 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# get_max_total_num_tokens\n",
+   "# flush cache\n",
    "\n",
-   "url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+   "url = \"http://localhost:30010/flush_cache\"\n",
    "\n",
-   "response = requests.get(url)\n",
+   "response = requests.post(url)\n",
    "print_highlight(response.text)"
  ]
 },
python/sglang/__init__.py

@@ -11,7 +11,7 @@ from sglang.api import (
     gen,
     gen_int,
     gen_string,
-    get_server_args,
+    get_server_info,
     image,
     select,
     set_default_backend,
@@ -41,7 +41,7 @@ __all__ = [
     "gen",
     "gen_int",
     "gen_string",
-    "get_server_args",
+    "get_server_info",
     "image",
     "select",
     "set_default_backend",
python/sglang/api.py

@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     return backend.flush_cache()


-def get_server_args(backend: Optional[BaseBackend] = None):
+def get_server_info(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     # If backend is Runtime
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
-    return backend.get_server_args()
+    return backend.get_server_info()


 def gen(
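A short usage sketch of the renamed frontend helper, assuming a RuntimeEndpoint backend pointed at a running server (the URL is illustrative, matching the notebook's port):

import sglang as sgl

# get_server_info() falls back to the global default backend, unwraps a
# Runtime's .endpoint if present, and issues the HTTP request.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30010"))
server_info = sgl.get_server_info()
print(server_info["tokenizer_path"])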
python/sglang/lang/backend/base_backend.py

@@ -78,5 +78,5 @@ class BaseBackend:
     def flush_cache(self):
         pass

-    def get_server_args(self):
+    def get_server_info(self):
         pass
python/sglang/lang/backend/runtime_endpoint.py

@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)

-    def get_server_args(self):
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/get_server_args",
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
python/sglang/srt/server.py

@@ -146,10 +146,15 @@ async def get_model_info():
     return result


-@app.get("/get_server_args")
-async def get_server_args():
-    """Get the server arguments."""
-    return dataclasses.asdict(tokenizer_manager.server_args)
+@app.get("/get_server_info")
+async def get_server_info():
+    try:
+        return await _get_server_info()
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )


 @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
     )


-@app.get("/get_max_total_num_tokens")
-async def get_max_total_num_tokens():
-    try:
-        return {"max_total_num_tokens": _get_max_total_num_tokens()}
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
-@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
-async def get_memory_pool_size():
-    """Get the memory pool size in number of tokens"""
-    try:
-        ret = await tokenizer_manager.get_memory_pool_size()
-        return ret
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.post("/update_weights")
 @time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
     t.join()


-def _get_max_total_num_tokens():
-    return _max_total_num_tokens
+async def _get_server_info():
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
+        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+    }


 def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

-    def get_max_total_num_tokens(self):
-        response = requests.get(f"{self.url}/get_max_total_num_tokens")
-        if response.status_code == 200:
-            return response.json()["max_total_num_tokens"]
-        else:
-            raise RuntimeError(
-                f"Failed to get max tokens. {response.json()['error']['message']}"
-            )
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )

     def __del__(self):
         self.shutdown()
@@ -946,5 +933,5 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

-    def get_max_total_num_tokens(self):
-        return _get_max_total_num_tokens()
+    async def get_server_info(self):
+        return await _get_server_info()
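Note that `Runtime.get_server_info` is now a coroutine (aiohttp rather than requests), so synchronous callers must drive it with asyncio. A minimal sketch, assuming a locally launched Runtime; the model path is illustrative:

import asyncio

import sglang as sgl

# Runtime.get_server_info() is async after this commit; run it through
# asyncio when calling from synchronous code.
runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
info = asyncio.run(runtime.get_server_info())
print(info["max_total_num_tokens"], info["memory_pool_size"])
runtime.shutdown()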
rust/src/server.rs

@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
     forward_request(&data.client, worker_url, "/health_generate".to_string()).await
 }

-#[get("/get_server_args")]
-async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+#[get("/get_server_info")]
+async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };

-    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
 }

 #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
             .service(get_model_info)
             .service(health)
             .service(health_generate)
-            .service(get_server_args)
+            .service(get_server_info)
     })
     .bind((host, port))?
     .run()
test/srt/test_data_parallelism.py

@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200

     def test_get_memory_pool_size(self):
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200

         time.sleep(5)
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
test/srt/test_srt_endpoint.py

@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

-    def test_get_memory_pool_size(self):
-        response = requests.post(self.base_url + "/get_memory_pool_size")
-        self.assertIsInstance(response.json(), int)
+    def test_get_server_info(self):
+        response = requests.get(self.base_url + "/get_server_info")
+        response_json = response.json()
+
+        max_total_num_tokens = response_json["max_total_num_tokens"]
+        self.assertIsInstance(max_total_num_tokens, int)
+
+        memory_pool_size = response_json["memory_pool_size"]
+        self.assertIsInstance(memory_pool_size, int)
+
+        attention_backend = response_json["attention_backend"]
+        self.assertIsInstance(attention_backend, str)


 if __name__ == "__main__":