sglang / Commits

Commit bc20e93f (unverified), authored Feb 27, 2025 by KCFindstr, committed by GitHub on Feb 27, 2025

[feat] Add Vertex AI compatible prediction route for /generate (#3866)
parent d3887852

Showing 5 changed files with 152 additions and 0 deletions
examples/runtime/vertex_predict.py             +66 -0
python/sglang/srt/entrypoints/http_server.py   +27 -0
python/sglang/srt/managers/io_struct.py         +6 -0
test/srt/run_suite.py                           +1 -0
test/srt/test_vertex_endpoint.py               +52 -0
examples/runtime/vertex_predict.py (new file, mode 100644)
"""
Usage:
python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000
python vertex_predict.py
This example shows the request and response formats of the prediction route for
Google Cloud Vertex AI Online Predictions.
The Vertex AI SDK for Python is recommended for deploying models to Vertex AI
instead of a local server. After deploying the model to a Vertex AI Online
Prediction Endpoint, send requests via the Python SDK:

    response = endpoint.predict(
        instances=[
            {"text": "The capital of France is"},
            {"text": "What is a car?"},
        ],
        parameters={"sampling_params": {"max_new_tokens": 16}},
    )
    print(response.predictions)

More details about getting online predictions from Vertex AI can be found at
https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.
"""
from dataclasses import dataclass
from typing import List, Optional

import requests


@dataclass
class VertexPrediction:
    predictions: List


class LocalVertexEndpoint:
    def __init__(self) -> None:
        self.base_url = "http://127.0.0.1:30000"

    def predict(self, instances: List[dict], parameters: Optional[dict] = None):
        response = requests.post(
            self.base_url + "/vertex_generate",
            json={
                "instances": instances,
                "parameters": parameters,
            },
        )
        return VertexPrediction(predictions=response.json()["predictions"])


endpoint = LocalVertexEndpoint()

# Predict with a single prompt.
response = endpoint.predict(instances=[{"text": "The capital of France is"}])
print(response.predictions)

# Predict with multiple prompts and parameters.
response = endpoint.predict(
    instances=[
        {"text": "The capital of France is"},
        {"text": "What is a car?"},
    ],
    parameters={"sampling_params": {"max_new_tokens": 16}},
)
print(response.predictions)
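
Each element of response.predictions is the server's /generate output for the matching instance. A minimal sketch of reading one result, assuming the usual sglang response shape (a dict with "text" and "meta_info" keys; exact fields vary by server version):

    # Assumes each prediction looks like {"text": "...", "meta_info": {...}};
    # this is the shape sglang's /generate typically returns, but check your version.
    first = response.predictions[0]
    print(first["text"])       # the generated continuation
    print(first["meta_info"])  # token counts, finish reason, etc. (version-dependent)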
python/sglang/srt/entrypoints/http_server.py
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request):
     return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
 
 
+## Vertex AI API
+
+
+@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
+async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
+    if not vertex_req.instances:
+        return []
+    inputs = {}
+    for input_key in ("text", "input_ids", "input_embeds"):
+        if vertex_req.instances[0].get(input_key):
+            inputs[input_key] = [
+                instance.get(input_key) for instance in vertex_req.instances
+            ]
+            break
+    image_data = [
+        instance.get("image_data")
+        for instance in vertex_req.instances
+        if instance.get("image_data") is not None
+    ] or None
+    req = GenerateReqInput(
+        **inputs,
+        image_data=image_data,
+        **(vertex_req.parameters or {}),
+    )
+    ret = await generate_request(req, raw_request)
+    return ORJSONResponse({"predictions": ret})
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
     )
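The route path is read from the AIP_PREDICT_ROUTE environment variable that Vertex AI injects into custom serving containers, falling back to /vertex_generate for local runs. Because the handler batches whichever of "text", "input_ids", or "input_embeds" appears in the first instance, pre-tokenized input also works; a minimal sketch against a local server (the token IDs below are placeholders, not valid for any particular model):

    import requests

    # Placeholder IDs; in practice, produce these with the loaded model's tokenizer.
    payload = {
        "instances": [
            {"input_ids": [1, 450, 7483, 310]},
            {"input_ids": [1, 1724, 338, 263]},
        ],
        "parameters": {"sampling_params": {"max_new_tokens": 8}},
    }
    resp = requests.post("http://127.0.0.1:30000/vertex_generate", json=payload)
    print(resp.json()["predictions"])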
python/sglang/srt/managers/io_struct.py
@@ -568,3 +568,9 @@ class FunctionCallReqInput:
     tool_call_parser: Optional[str] = (
         None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
     )
+
+
+@dataclass
+class VertexGenerateReqInput:
+    instances: List[dict]
+    parameters: Optional[dict] = None
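
Since the handler splats vertex_req.parameters directly into GenerateReqInput, any keyword GenerateReqInput accepts can be passed through "parameters"; a sketch mirroring the examples above (which keys are valid depends on the GenerateReqInput definition in this file):

    req = VertexGenerateReqInput(
        instances=[{"text": "The capital of France is"}],
        parameters={"sampling_params": {"max_new_tokens": 16}},
    )
    # Equivalent JSON body for a raw POST to the route:
    # {"instances": [{"text": "The capital of France is"}],
    #  "parameters": {"sampling_params": {"max_new_tokens": 16}}}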
test/srt/run_suite.py
@@ -50,6 +50,7 @@ suites = {
     "test_hidden_states.py",
     "test_update_weights_from_disk.py",
     "test_update_weights_from_tensor.py",
+    "test_vertex_endpoint.py",
     "test_vision_chunked_prefill.py",
     "test_vision_llm.py",
     "test_vision_openai_server.py",
test/srt/test_vertex_endpoint.py (new file, mode 100644)
"""
python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
"""
import
unittest
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
class
TestVertexEndpoint
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
run_generate
(
self
,
parameters
):
data
=
{
"instances"
:
[
{
"text"
:
"The capital of France is"
},
{
"text"
:
"The capital of China is"
},
],
"parameters"
:
parameters
,
}
response
=
requests
.
post
(
self
.
base_url
+
"/vertex_generate"
,
json
=
data
)
response_json
=
response
.
json
()
assert
len
(
response_json
[
"predictions"
])
==
len
(
data
[
"instances"
])
return
response_json
def
test_vertex_generate
(
self
):
for
parameters
in
[
None
,
{
"sampling_params"
:
{
"max_new_tokens"
:
4
}}]:
self
.
run_generate
(
parameters
)
if
__name__
==
"__main__"
:
unittest
.
main
()