sglang / Commits / dbe17293

Unverified commit dbe17293, authored Nov 24, 2024 by Henry Hyeonmok Ko, committed by GitHub on Nov 24, 2024.

Merged three native APIs into one: get_server_info (#2152)

Parent: 84a1698d
10 changed files with 74 additions and 119 deletions.
- benchmark/json_schema/bench_sglang.py (+1, -1)
- docs/backend/native_api.ipynb (+20, -62)
- python/sglang/__init__.py (+2, -2)
- python/sglang/api.py (+2, -2)
- python/sglang/lang/backend/base_backend.py (+1, -1)
- python/sglang/lang/backend/runtime_endpoint.py (+2, -2)
- python/sglang/srt/server.py (+27, -40)
- rust/src/server.rs (+4, -4)
- test/srt/test_data_parallelism.py (+3, -2)
- test/srt/test_srt_endpoint.py (+12, -3)
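The net effect is easiest to see from the client side: three GET endpoints collapse into one. Below is a minimal sketch of the migration, assuming a local server on port 30010 (the address used in the docs notebook); per the new `_get_server_info` helper, the merged response flattens the server args into the top level and adds the `memory_pool_size` and `max_total_num_tokens` fields that previously had dedicated endpoints.

```python
import requests

base_url = "http://localhost:30010"  # assumed local SGLang server, as in the docs notebook

# Before this commit, each piece of information had its own endpoint:
#   requests.get(f"{base_url}/get_server_args").json()
#   requests.get(f"{base_url}/get_memory_pool_size").json()
#   requests.get(f"{base_url}/get_max_total_num_tokens").json()["max_total_num_tokens"]

# After this commit, a single endpoint returns everything:
info = requests.get(f"{base_url}/get_server_info").json()
print(info["tokenizer_path"])        # a server arg, flattened into the top level
print(info["memory_pool_size"])      # formerly /get_memory_pool_size
print(info["max_total_num_tokens"])  # formerly /get_max_total_num_tokens
```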
benchmark/json_schema/bench_sglang.py

@@ -113,7 +113,7 @@ def main(args):
     # Compute accuracy
     tokenizer = get_tokenizer(
-        global_config.default_backend.get_server_args()["tokenizer_path"]
+        global_config.default_backend.get_server_info()["tokenizer_path"]
     )
     output_jsons = [state["json_output"] for state in states]
     num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
docs/backend/native_api.ipynb

@@ -9,13 +9,11 @@
 "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
 "\n",
 "- `/generate` (text generation model)\n",
-"- `/get_server_args`\n",
 "- `/get_model_info`\n",
+"- `/get_server_info`\n",
 "- `/health`\n",
 "- `/health_generate`\n",
 "- `/flush_cache`\n",
-"- `/get_memory_pool_size`\n",
-"- `/get_max_total_num_tokens`\n",
 "- `/update_weights`\n",
 "- `/encode`(embedding model)\n",
 "- `/classify`(reward model)\n",
@@ -75,26 +73,6 @@
 "print_highlight(response.json())"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Server Args\n",
-"Get the arguments of a server."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/get_server_args\"\n",
-"\n",
-"response = requests.get(url)\n",
-"print_highlight(response.json())"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -127,9 +105,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Health Check\n",
-"- `/health`: Check the health of the server.\n",
-"- `/health_generate`: Check the health of the server by generating one token."
+"## Get Server Info\n",
+"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
+"- Note: `get_server_info` merges the following deprecated endpoints:\n",
+"  - `get_server_args`\n",
+"  - `get_memory_pool_size`\n",
+"  - `get_max_total_num_tokens`"
 ]
 },
 {
@@ -138,19 +119,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"url = \"http://localhost:30010/health_generate\"\n",
+"# get_server_info\n",
 "\n",
-"response = requests.get(url)\n",
-"print_highlight(response.text)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/health\"\n",
+"url = \"http://localhost:30010/get_server_info\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -160,9 +131,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Flush Cache\n",
-"\n",
-"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
+"## Health Check\n",
+"- `/health`: Check the health of the server.\n",
+"- `/health_generate`: Check the health of the server by generating one token."
 ]
 },
 {
@@ -171,32 +142,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# flush cache\n",
-"\n",
-"url = \"http://localhost:30010/flush_cache\"\n",
+"url = \"http://localhost:30010/health_generate\"\n",
 "\n",
-"response = requests.post(url)\n",
+"response = requests.get(url)\n",
 "print_highlight(response.text)"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Memory Pool Size\n",
-"\n",
-"Get the memory pool size in number of tokens.\n"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_memory_pool_size\n",
-"\n",
-"url = \"http://localhost:30010/get_memory_pool_size\"\n",
+"url = \"http://localhost:30010/health\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -206,9 +164,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Get Maximum Total Number of Tokens\n",
+"## Flush Cache\n",
 "\n",
-"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
 ]
 },
 {
@@ -217,11 +175,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_max_total_num_tokens\n",
+"# flush cache\n",
 "\n",
-"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+"url = \"http://localhost:30010/flush_cache\"\n",
 "\n",
-"response = requests.get(url)\n",
+"response = requests.post(url)\n",
 "print_highlight(response.text)"
 ]
 },
python/sglang/__init__.py

@@ -11,7 +11,7 @@ from sglang.api import (
     gen,
     gen_int,
     gen_string,
-    get_server_args,
+    get_server_info,
     image,
     select,
     set_default_backend,
@@ -41,7 +41,7 @@ __all__ = [
     "gen",
     "gen_int",
     "gen_string",
-    "get_server_args",
+    "get_server_info",
     "image",
     "select",
     "set_default_backend",
python/sglang/api.py

@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     return backend.flush_cache()


-def get_server_args(backend: Optional[BaseBackend] = None):
+def get_server_info(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     # If backend is Runtime
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
-    return backend.get_server_args()
+    return backend.get_server_info()


 def gen(
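For frontend-language users, the renamed `sglang.get_server_info` keeps the old `get_server_args` calling convention: with no argument it resolves the default backend and forwards to that backend's `get_server_info`. A minimal sketch, where the endpoint URL is an assumption for illustration:

```python
import sglang

# Point the frontend at a running server; the URL is illustrative.
sglang.set_default_backend(sglang.RuntimeEndpoint("http://localhost:30010"))

# With no explicit backend argument, this resolves the default backend
# and issues a GET to its /get_server_info endpoint.
info = sglang.get_server_info()
print(info["tokenizer_path"])
```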
python/sglang/lang/backend/base_backend.py

@@ -78,5 +78,5 @@ class BaseBackend:
     def flush_cache(self):
         pass

-    def get_server_args(self):
+    def get_server_info(self):
         pass
python/sglang/lang/backend/runtime_endpoint.py

@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)

-    def get_server_args(self):
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/get_server_args",
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
python/sglang/srt/server.py

@@ -146,10 +146,15 @@ async def get_model_info():
     return result


-@app.get("/get_server_args")
-async def get_server_args():
-    """Get the server arguments."""
-    return dataclasses.asdict(tokenizer_manager.server_args)
+@app.get("/get_server_info")
+async def get_server_info():
+    try:
+        return await _get_server_info()
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )


 @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
     )


-@app.get("/get_max_total_num_tokens")
-async def get_max_total_num_tokens():
-    try:
-        return {"max_total_num_tokens": _get_max_total_num_tokens()}
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
-@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
-async def get_memory_pool_size():
-    """Get the memory pool size in number of tokens"""
-    try:
-        ret = await tokenizer_manager.get_memory_pool_size()
-        return ret
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.post("/update_weights")
 @time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
         t.join()


-def _get_max_total_num_tokens():
-    return _max_total_num_tokens
+async def _get_server_info():
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
+        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+    }


 def _set_envs_and_config(server_args: ServerArgs):
@@ -787,14 +772,16 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

-    def get_max_total_num_tokens(self):
-        response = requests.get(f"{self.url}/get_max_total_num_tokens")
-        if response.status_code == 200:
-            return response.json()["max_total_num_tokens"]
-        else:
-            raise RuntimeError(
-                f"Failed to get max tokens. {response.json()['error']['message']}"
-            )
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )

     def __del__(self):
         self.shutdown()
@@ -946,5 +933,5 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

-    def get_max_total_num_tokens(self):
-        return _get_max_total_num_tokens()
+    async def get_server_info(self):
+        return await _get_server_info()
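Note that `Runtime.get_server_info` and `Engine.get_server_info` are now coroutines (the Runtime version fetches over HTTP with aiohttp), so synchronous callers have to drive them with an event loop. A minimal sketch, with a placeholder model path:

```python
import asyncio

import sglang

# Launch an in-process runtime; the model path is a placeholder.
runtime = sglang.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")

# get_server_info is async after this commit, so run it to completion explicitly.
info = asyncio.run(runtime.get_server_info())
print(info["max_total_num_tokens"])

runtime.shutdown()
```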
rust/src/server.rs

@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
     forward_request(&data.client, worker_url, "/health_generate".to_string()).await
 }

-#[get("/get_server_args")]
-async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+#[get("/get_server_info")]
+async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };

-    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
 }

 #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
         .service(get_model_info)
         .service(health)
         .service(health_generate)
-        .service(get_server_args)
+        .service(get_server_info)
     })
     .bind((host, port))?
     .run()
test/srt/test_data_parallelism.py

@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200

     def test_get_memory_pool_size(self):
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200

         time.sleep(5)

-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
test/srt/test_srt_endpoint.py

@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

-    def test_get_memory_pool_size(self):
-        response = requests.post(self.base_url + "/get_memory_pool_size")
-        self.assertIsInstance(response.json(), int)
+    def test_get_server_info(self):
+        response = requests.get(self.base_url + "/get_server_info")
+        response_json = response.json()
+
+        max_total_num_tokens = response_json["max_total_num_tokens"]
+        self.assertIsInstance(max_total_num_tokens, int)
+
+        memory_pool_size = response_json["memory_pool_size"]
+        self.assertIsInstance(memory_pool_size, int)
+
+        attention_backend = response_json["attention_backend"]
+        self.assertIsInstance(attention_backend, str)


 if __name__ == "__main__":