Unverified Commit d33955d2 authored by KCFindstr's avatar KCFindstr Committed by GitHub
Browse files

Properly return error response in vertex_generate HTTP endpoint (#5956)

parent 6fc17596
...@@ -675,6 +675,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque ...@@ -675,6 +675,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
**(vertex_req.parameters or {}), **(vertex_req.parameters or {}),
) )
ret = await generate_request(req, raw_request) ret = await generate_request(req, raw_request)
if isinstance(ret, Response):
return ret
return ORJSONResponse({"predictions": ret}) return ORJSONResponse({"predictions": ret})
......
...@@ -3,6 +3,7 @@ python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate ...@@ -3,6 +3,7 @@ python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
""" """
import unittest import unittest
from http import HTTPStatus
import requests import requests
...@@ -49,6 +50,15 @@ class TestVertexEndpoint(CustomTestCase): ...@@ -49,6 +50,15 @@ class TestVertexEndpoint(CustomTestCase):
for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]: for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]:
self.run_generate(parameters) self.run_generate(parameters)
def test_vertex_generate_fail(self):
    """Check that a rejected /vertex_generate request surfaces as HTTP 400.

    The endpoint is expected to pass an error response through unchanged
    instead of wrapping it in a successful ``{"predictions": ...}`` body.
    """
    # NOTE(review): the server is expected to reject this payload; which
    # field makes it invalid is not visible from this test alone -- confirm
    # against the request schema.
    data = {
        "instances": [
            {"prompt": "The capital of France is"},
        ],
    }
    response = requests.post(self.base_url + "/vertex_generate", json=data)
    # Use the TestCase assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and reports both values on failure.
    self.assertEqual(
        response.status_code,
        HTTPStatus.BAD_REQUEST,
        msg=response.text,
    )
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment