"ts/webui/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "af5551f929b1906345abd5866959c12c83c7bdbd"
Commit 908cd5ea authored by Morgan Funtowicz, committed by Lysandre Debut

Make forward asynchronous to avoid long computations timing out.


Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 6e6c8c52
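
FastAPI executes a plain `def` endpoint in a worker thread, while an `async def` endpoint runs as a coroutine on the event loop and can yield with `await` while a long forward pass is in flight, which is presumably why the handler is made asynchronous here. A minimal sketch of the two handler kinds, using a throwaway app rather than the real ServeCommand:

    import asyncio

    from fastapi import FastAPI

    app = FastAPI()

    # Plain function: Starlette offloads it to a threadpool worker.
    @app.get("/sync")
    def sync_handler():
        return {"runs_on": "threadpool"}

    # Coroutine: awaited directly on the event loop; `await` points let
    # other requests make progress during long computations.
    @app.get("/async")
    async def async_handler():
        await asyncio.sleep(0)  # cooperative yield point
        return {"runs_on": "event loop"}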
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
+from starlette.responses import JSONResponse
 from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
 try:
     from uvicorn import run
     from fastapi import FastAPI, HTTPException, Body
+    from fastapi.routing import APIRoute
     from pydantic import BaseModel

     _serve_dependancies_installed = True
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
         tokenizer=args.tokenizer,
         device=args.device,
     )
-    return ServeCommand(nlp, args.host, args.port)
+    return ServeCommand(nlp, args.host, args.port, args.workers)


 class ServeModelInfoResult(BaseModel):
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
         serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
+        serve_parser.add_argument("--workers", type=int, default=1, help="Number of HTTP workers.")
         serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
         serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
         serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, pipeline: Pipeline, host: str, port: int):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
         self._pipeline = pipeline

-        self._host = host
-        self._port = port
+        self.host = host
+        self.port = port
+        self.workers = workers

         if not _serve_dependancies_installed:
             raise RuntimeError(
"Using serve command requires FastAPI and unicorn. " "Using serve command requires FastAPI and unicorn. "
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
             )
         else:
             logger.info("Serving model over {}:{}".format(host, port))
-            self._app = FastAPI()
-
-            # Register routes
-            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
-            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
-            self._app.add_api_route(
-                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
-            )
-            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
+            self._app = FastAPI(
+                routes=[
+                    APIRoute(
+                        "/",
+                        self.model_info,
+                        response_model=ServeModelInfoResult,
+                        response_class=JSONResponse,
+                        methods=["GET"],
+                    ),
+                    APIRoute(
+                        "/tokenize",
+                        self.tokenize,
+                        response_model=ServeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/detokenize",
+                        self.detokenize,
+                        response_model=ServeDeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/forward",
+                        self.forward,
+                        response_model=ServeForwardResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                ],
+                timeout=600,
+            )

     def run(self):
-        run(self._app, host=self._host, port=self._port)
+        run(self._app, host=self.host, port=self.port, workers=self.workers)

     def model_info(self):
         return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
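
Registration moves from imperative `add_api_route` calls to `fastapi.routing.APIRoute` objects declared up front in the `FastAPI` constructor; both are public FastAPI APIs and produce the same routing table. A self-contained sketch of the declarative style, with `InfoResult` and `model_info` as simplified stand-ins for the pipeline-backed versions above:

    from fastapi import FastAPI
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse

    class InfoResult(BaseModel):
        infos: dict

    def model_info() -> InfoResult:
        # Stand-in for ServeCommand.model_info, which returns the model config.
        return InfoResult(infos={"model_type": "demo"})

    app = FastAPI(
        routes=[
            APIRoute(
                "/",
                model_info,
                response_model=InfoResult,
                response_class=JSONResponse,
                methods=["GET"],
            ),
        ]
    )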
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
         except Exception as e:
             raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

-    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
+    async def forward(self, inputs=Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
...
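
Dropping the `Union[...]` annotation leaves `inputs` untyped, so the `Body(None, embed=True)` parameter accepts any JSON value, and `embed=True` means clients must wrap the payload as `{"inputs": ...}`. A minimal sketch of the new handler shape, outside the real class:

    from fastapi import Body, FastAPI

    app = FastAPI()

    @app.post("/forward")
    async def forward(inputs=Body(None, embed=True)):
        # With embed=True the request body is {"inputs": ...}; the untyped
        # parameter takes a string, a list, or a dict alike.
        return {"output": inputs}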
@@ -28,8 +28,9 @@ from . import __version__
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 try:
-    if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
         import torch

         _torch_available = True  # pylint: disable=invalid-name
@@ -41,8 +42,10 @@ except ImportError:
     _torch_available = False  # pylint: disable=invalid-name

 try:
-    if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
         import tensorflow as tf

         assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
...
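
Hoisting the environment lookups into `USE_TF` / `USE_TORCH` names makes the selection condition read directly, and the behavior stays the same: a value outside ("1", "ON", "YES", "AUTO") disables that backend, while "AUTO" probes it only if the other backend was not explicitly switched on. A hedged usage sketch, forcing the PyTorch-only path (the variables must be set before `transformers` is imported, since detection runs at import time):

    import os

    os.environ["USE_TF"] = "0"     # anything outside ("1", "ON", "YES", "AUTO") skips TF
    os.environ["USE_TORCH"] = "1"  # explicitly request the PyTorch backend

    import transformers  # noqa: E402  (import after env setup is deliberate)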