Commit 908cd5ea authored by Morgan Funtowicz, committed by Lysandre Debut

Make forward asynchronous to avoid long computations timing out.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 6e6c8c52
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
+
+from starlette.responses import JSONResponse

 from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
 try:
     from uvicorn import run
     from fastapi import FastAPI, HTTPException, Body
+    from fastapi.routing import APIRoute
     from pydantic import BaseModel
     _serve_dependancies_installed = True
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
         tokenizer=args.tokenizer,
         device=args.device,
     )
-    return ServeCommand(nlp, args.host, args.port)
+    return ServeCommand(nlp, args.host, args.port, args.workers)
 class ServeModelInfoResult(BaseModel):
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
         serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
+        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
         serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
         serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
         serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, pipeline: Pipeline, host: str, port: int):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
         self._pipeline = pipeline
-        self._host = host
-        self._port = port
+        self.host = host
+        self.port = port
+        self.workers = workers
         if not _serve_dependancies_installed:
             raise RuntimeError(
                 "Using serve command requires FastAPI and uvicorn. "
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
             )
         else:
             logger.info("Serving model over {}:{}".format(host, port))
-            self._app = FastAPI()
-            # Register routes
-            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
-            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
-            self._app.add_api_route(
-                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
+            self._app = FastAPI(
+                routes=[
+                    APIRoute(
+                        "/",
+                        self.model_info,
+                        response_model=ServeModelInfoResult,
+                        response_class=JSONResponse,
+                        methods=["GET"],
+                    ),
+                    APIRoute(
+                        "/tokenize",
+                        self.tokenize,
+                        response_model=ServeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/detokenize",
+                        self.detokenize,
+                        response_model=ServeDeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/forward",
+                        self.forward,
+                        response_model=ServeForwardResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                ],
+                timeout=600,
             )
-            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
     def run(self):
-        run(self._app, host=self._host, port=self._port)
+        run(self._app, host=self.host, port=self.port, workers=self.workers)
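The new workers value is simply forwarded to uvicorn's run(). A hedged sketch of what that invocation looks like outside the CLI; note that recent uvicorn releases only honour workers > 1 when the application is passed as an import string, so "my_serve_module:app" below is a hypothetical module path, not part of this commit:

# Sketch: multi-process serving with uvicorn (assumes uvicorn is installed).
import uvicorn

if __name__ == "__main__":
    # "my_serve_module:app" is a hypothetical import string for illustration;
    # given an import string, uvicorn can fork the requested number of workers.
    uvicorn.run("my_serve_module:app", host="localhost", port=8888, workers=4)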
     def model_info(self):
         return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
         except Exception as e:
             raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

-    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
+    async def forward(self, inputs=Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
......
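The core of the commit is the switch from def forward to async def forward, so the handler is scheduled on the event loop rather than the synchronous request path. One common companion pattern (an assumption here, not shown in this diff) is to push the blocking model call onto a thread executor so the loop stays responsive during long computations; a minimal sketch:

# Sketch: an async endpoint that offloads a blocking model call to a thread.
# blocking_pipeline is a stand-in for a slow transformers pipeline call.
import asyncio

from fastapi import Body, FastAPI

app = FastAPI()

def blocking_pipeline(inputs):
    # Pretend this takes a long time on CPU/GPU.
    return {"output": inputs}

@app.post("/forward")
async def forward(inputs=Body(None, embed=True)):
    loop = asyncio.get_running_loop()
    # run_in_executor keeps the event loop free to accept other requests
    # while the model computes in a worker thread.
    return await loop.run_in_executor(None, blocking_pipeline, inputs)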
@@ -28,8 +28,9 @@ from . import __version__
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 try:
-    if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
         import torch
         _torch_available = True  # pylint: disable=invalid-name
@@ -41,8 +42,10 @@ except ImportError:
     _torch_available = False  # pylint: disable=invalid-name
 try:
-    if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
         import tensorflow as tf
         assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
......
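Since the refactored detection reads USE_TF and USE_TORCH once at import time, the backend has to be chosen before transformers is imported. A hedged usage sketch (assumes torch is installed; is_torch_available is the public transformers helper):

# Sketch: selecting the backend through environment variables.
import os

os.environ["USE_TORCH"] = "1"  # any of "1", "ON", "YES" (or "AUTO") enables torch
os.environ["USE_TF"] = "NO"    # anything outside ("1", "ON", "YES") keeps TF off

import transformers  # noqa: E402  -- imported after the env vars on purpose

print(transformers.is_torch_available())  # True when torch is installed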