Unverified Commit bc20e93f authored by KCFindstr, committed by GitHub

[feat] Add Vertex AI compatible prediction route for /generate (#3866)

parent d3887852
"""
Usage:
python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000
python vertex_predict.py
This example shows the request and response formats of the prediction route for
Google Cloud Vertex AI Online Predictions.
Vertex AI SDK for Python is recommended for deploying models to Vertex AI
instead of a local server. After deploying the model to a Vertex AI Online
Prediction Endpoint, send requests via the Python SDK:
response = endpoint.predict(
instances=[
{"text": "The capital of France is"},
{"text": "What is a car?"},
],
parameters={"sampling_params": {"max_new_tokens": 16}},
)
print(response.predictions)
More details about get online predictions from Vertex AI can be found at
https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.
"""
from dataclasses import dataclass
from typing import List, Optional

import requests


@dataclass
class VertexPrediction:
    predictions: List


class LocalVertexEndpoint:
    """Minimal stand-in for a Vertex AI endpoint, backed by a local server."""

    def __init__(self) -> None:
        self.base_url = "http://127.0.0.1:30000"

    def predict(self, instances: List[dict], parameters: Optional[dict] = None):
        response = requests.post(
            self.base_url + "/vertex_generate",
            json={
                "instances": instances,
                "parameters": parameters,
            },
        )
        return VertexPrediction(predictions=response.json()["predictions"])
endpoint = LocalVertexEndpoint()

# Predict with a single prompt.
response = endpoint.predict(instances=[{"text": "The capital of France is"}])
print(response.predictions)

# Predict with multiple prompts and parameters.
response = endpoint.predict(
    instances=[
        {"text": "The capital of France is"},
        {"text": "What is a car?"},
    ],
    parameters={"sampling_params": {"max_new_tokens": 16}},
)
print(response.predictions)
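
For reference, the route speaks plain JSON, so the helper class above is optional. A minimal sketch of the equivalent raw request (assuming the same local server on port 30000; per the route added below, the response wraps the /generate output in a "predictions" list, one entry per instance):

import requests

# Same payload shape that LocalVertexEndpoint.predict builds above.
payload = {
    "instances": [{"text": "The capital of France is"}],
    "parameters": {"sampling_params": {"max_new_tokens": 16}},
}
resp = requests.post("http://127.0.0.1:30000/vertex_generate", json=payload)
print(resp.json()["predictions"])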
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
    ResumeMemoryOccupationReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
@@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request):
    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
## Vertex AI API


@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
    if not vertex_req.instances:
        return []
    # Batch whichever input field the first instance provides across all instances.
    inputs = {}
    for input_key in ("text", "input_ids", "input_embeds"):
        if vertex_req.instances[0].get(input_key):
            inputs[input_key] = [
                instance.get(input_key) for instance in vertex_req.instances
            ]
            break
    # Collect per-instance image data; None if no instance carries any.
    image_data = [
        instance.get("image_data")
        for instance in vertex_req.instances
        if instance.get("image_data") is not None
    ] or None
    req = GenerateReqInput(
        **inputs,
        image_data=image_data,
        **(vertex_req.parameters or {}),
    )
    ret = await generate_request(req, raw_request)
    return ORJSONResponse({"predictions": ret})

def _create_error_response(e):
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
......
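
The handler batches whichever input field the first instance provides ("text", "input_ids", or "input_embeds") across all instances, and keeps only non-None "image_data" entries. A standalone sketch of that collation step, using hypothetical instances (it assumes every instance uses the same input field, as the handler does):

from typing import List


def collate_instances(instances: List[dict]) -> dict:
    # Pick the input key present in the first instance and gather it from all.
    inputs = {}
    for input_key in ("text", "input_ids", "input_embeds"):
        if instances and instances[0].get(input_key):
            inputs[input_key] = [inst.get(input_key) for inst in instances]
            break
    # Only non-None image_data entries are kept; None if no instance has any.
    image_data = [
        inst.get("image_data")
        for inst in instances
        if inst.get("image_data") is not None
    ] or None
    return {"image_data": image_data, **inputs}


# Hypothetical mixed batch: two text prompts, one with an attached image.
print(collate_instances([
    {"text": "Describe this image.", "image_data": "example.png"},
    {"text": "The capital of France is"},
]))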
@@ -568,3 +568,9 @@ class FunctionCallReqInput:
    tool_call_parser: Optional[str] = (
        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
    )


@dataclass
class VertexGenerateReqInput:
    instances: List[dict]
    parameters: Optional[dict] = None
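
The request schema is thin: FastAPI validates the JSON body into this dataclass, and the route later splats parameters into GenerateReqInput. A minimal illustration of the mapping (values are hypothetical):

from sglang.srt.managers.io_struct import VertexGenerateReqInput

body = {
    "instances": [{"text": "What is a car?"}],
    "parameters": {"sampling_params": {"max_new_tokens": 4}},
}
req = VertexGenerateReqInput(**body)
assert req.instances[0]["text"] == "What is a car?"
assert req.parameters["sampling_params"]["max_new_tokens"] == 4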
@@ -50,6 +50,7 @@ suites = {
        "test_hidden_states.py",
        "test_update_weights_from_disk.py",
        "test_update_weights_from_tensor.py",
        "test_vertex_endpoint.py",
        "test_vision_chunked_prefill.py",
        "test_vision_llm.py",
        "test_vision_openai_server.py",
......
"""
python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
"""
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestVertexEndpoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_generate(self, parameters):
data = {
"instances": [
{"text": "The capital of France is"},
{"text": "The capital of China is"},
],
"parameters": parameters,
}
response = requests.post(self.base_url + "/vertex_generate", json=data)
response_json = response.json()
assert len(response_json["predictions"]) == len(data["instances"])
return response_json
def test_vertex_generate(self):
for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]:
self.run_generate(parameters)
if __name__ == "__main__":
unittest.main()
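
One edge case the committed test does not cover: the route short-circuits an empty instances list to an empty response before any generation is scheduled. A hedged sketch of how that could be exercised against a running local server (not part of this commit; server address assumed as in the example above):

import requests

# From the route: `if not vertex_req.instances: return []`.
resp = requests.post(
    "http://127.0.0.1:30000/vertex_generate",
    json={"instances": []},
)
print(resp.json())  # expected: []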