Unverified Commit d33955d2 authored by KCFindstr's avatar KCFindstr Committed by GitHub
Browse files

Properly return error response in vertex_generate HTTP endpoint (#5956)

parent 6fc17596
......@@ -675,6 +675,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
**(vertex_req.parameters or {}),
)
ret = await generate_request(req, raw_request)
if isinstance(ret, Response):
return ret
return ORJSONResponse({"predictions": ret})
......
......@@ -3,6 +3,7 @@ python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
"""
import unittest
from http import HTTPStatus
import requests
......@@ -49,6 +50,15 @@ class TestVertexEndpoint(CustomTestCase):
for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]:
self.run_generate(parameters)
def test_vertex_generate_fail(self):
    """Check that a rejected /vertex_generate request surfaces as HTTP 400.

    Posts a payload whose single instance carries only a "prompt" key and
    expects the server to answer with ``HTTPStatus.BAD_REQUEST``.
    NOTE(review): presumably "prompt" is not a field the generate request
    accepts, which is what makes this request invalid -- confirm against
    the endpoint's request schema.
    """
    payload = {
        "instances": [
            {"prompt": "The capital of France is"},
        ],
    }
    response = requests.post(self.base_url + "/vertex_generate", json=payload)
    # Use the TestCase assertion instead of a bare `assert`: bare asserts are
    # removed under `python -O`, and assertEqual reports both values. The
    # response body is attached so a failure shows what the server returned.
    # NOTE(review): assumes CustomTestCase derives from unittest.TestCase.
    self.assertEqual(
        response.status_code,
        HTTPStatus.BAD_REQUEST,
        response.text,
    )
# Run the test cases in this module when the file is executed directly
# as a script rather than imported.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment