Commit a096e2a8 authored by Morgan Funtowicz

WIP serving through HTTP internally using pipelines.

parent 43a4e1bb
 from argparse import ArgumentParser, Namespace
 from typing import List, Optional, Union, Any

-import torch
 from fastapi import FastAPI, HTTPException, Body
 from logging import getLogger
 from pydantic import BaseModel
 from uvicorn import run
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
+from transformers.pipelines import SUPPORTED_TASKS, pipeline


 def serve_command_factory(args: Namespace):
@@ -17,7 +17,8 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    return ServeCommand(args.host, args.port, args.model, args.graphql)
+    nlp = pipeline(args.task, args.model)
+    return ServeCommand(nlp, args.host, args.port, args.model, args.graphql)

 class ServeResult(BaseModel):
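The factory change above is the core of the commit: instead of loading a model and tokenizer separately, a task-specific Pipeline now bundles both behind one callable. A minimal sketch of what that object provides, with names taken from the diff (the specific task and model names are illustrative only):

```python
from transformers.pipelines import pipeline

# pipeline() wraps model + tokenizer behind a single callable, which is
# what ServeCommand now receives instead of the raw components.
nlp = pipeline('feature-extraction', 'bert-base-uncased')

tokens = nlp.tokenizer.tokenize('Hello world')  # backs the tokenize endpoint
output = nlp('Hello world')                     # backs the forward endpoint
```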
@@ -53,8 +54,6 @@ class ServeForwardResult(ServeResult):
     """
     Forward result model
     """
-    tokens: List[str]
-    tokens_ids: List[int]
     output: Any
@@ -68,19 +67,18 @@ class ServeCommand(BaseTransformersCLICommand):
         :return:
         """
         serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
+        serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
+        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
         serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.')
         serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, host: str, port: int, model: str, graphql: bool):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool):
         self._logger = getLogger('transformers-cli/serving')
-        self._logger.info('Loading model {}'.format(model))
-        self._model_name = model
-        self._model = AutoModel.from_pretrained(model)
-        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._pipeline = pipeline

         self._logger.info('Serving model over {}:{}'.format(host, port))
         self._host = host
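For orientation, a small sketch of how the parsed arguments reach the factory. The `transformers-cli serve` entry-point name and the `ServeCommand.run()` wrapper are assumptions based on CLI plumbing that is not part of this diff:

```python
from argparse import Namespace

# Roughly what `transformers-cli serve --task feature-extraction
#   --model bert-base-uncased --port 8888` would produce (assumed entry point):
args = Namespace(task='feature-extraction', model='bert-base-uncased',
                 device=-1, host='localhost', port=8888, graphql=False)

serve = serve_command_factory(args)  # builds the pipeline and the ServeCommand
serve.run()                          # assumed to invoke the uvicorn run(...) shown below
```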
@@ -97,7 +95,7 @@ class ServeCommand(BaseTransformersCLICommand):
         run(self._app, host=self._host, port=self._port)

     def model_info(self):
-        return ServeModelInfoResult(model=self._model_name, infos=vars(self._model.config))
+        return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config))

     def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
         """
@@ -106,16 +104,16 @@ class ServeCommand(BaseTransformersCLICommand):
         - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
         """
         try:
-            tokens_txt = self._tokenizer.tokenize(text_input)
+            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

             if return_ids:
-                tokens_ids = self._tokenizer.convert_tokens_to_ids(tokens_txt)
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt, tokens_ids=tokens_ids)
+                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids)
             else:
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

     def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
                    skip_special_tokens: bool = Body(False, embed=True),
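Since the tokenize parameters are declared with `Body(..., embed=True)`, FastAPI expects each one as a named key in the JSON body. A sketch of a client call, assuming the handler is mounted at `/tokenize` (the route registration is outside this diff):

```python
import requests

# Hypothetical client call; the '/tokenize' path and port are assumptions.
resp = requests.post('http://localhost:8888/tokenize',
                     json={'text_input': 'Hello world', 'return_ids': True})
print(resp.json())  # e.g. {'model': '', 'tokens': [...], 'tokens_ids': [...]}
```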
@@ -127,14 +125,12 @@ class ServeCommand(BaseTransformersCLICommand):
         - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
         """
         try:
-            decoded_str = self._tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
-            return ServeDeTokenizeResult(model=self._model_name, text=decoded_str)
+            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
+            return ServeDeTokenizeResult(model='', text=decoded_str)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

-    def forward(self, inputs: Union[str, List[str], List[int]] = Body(None, embed=True),
-                attention_mask: Optional[List[int]] = Body(None, embed=True),
-                tokens_type_ids: Optional[List[int]] = Body(None, embed=True)):
+    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
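The detokenize endpoint is the inverse of tokenize. A sketch of a round trip, again assuming a `/detokenize` route (not shown in this diff); the token ids are illustrative BERT-style values:

```python
import requests

# Hypothetical call through the assumed '/detokenize' route.
resp = requests.post('http://localhost:8888/detokenize',
                     json={'tokens_ids': [101, 7592, 2088, 102],
                           'skip_special_tokens': True})
print(resp.json())  # e.g. {'model': '', 'text': 'hello world'}
```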
@@ -143,34 +139,13 @@ class ServeCommand(BaseTransformersCLICommand):
         # Check we don't have empty string
         if len(inputs) == 0:
-            return ServeForwardResult(model=self._model_name, output=[], attention=[])
+            return ServeForwardResult(model='', output=[], attention=[])

-        if isinstance(inputs, str):
-            inputs_tokens = self._tokenizer.tokenize(inputs)
-            inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-        elif isinstance(inputs, List):
-            if isinstance(inputs[0], str):
-                inputs_tokens = inputs
-                inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-            elif isinstance(inputs[0], int):
-                inputs_tokens = []
-                inputs_ids = inputs
-            else:
-                error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs[0]))
-                raise HTTPException(423, detail={"error": error_msg})
-        else:
-            error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs))
-            raise HTTPException(423, detail={"error": error_msg})

         try:
             # Forward through the model
-            t_input_ids = torch.tensor(inputs_ids).unsqueeze(0)
-            output = self._model(t_input_ids, attention_mask, tokens_type_ids)
+            output = self._pipeline(inputs)

             return ServeForwardResult(
-                model=self._model_name, tokens=inputs_tokens,
-                tokens_ids=inputs_ids, output=output[0].tolist()
+                model='', output=output
             )
         except Exception as e:
             raise HTTPException(500, {"error": str(e)})
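With the pipeline now owning tokenization and the forward pass, the manual str/ids branching above disappears and the endpoint accepts raw text. A sketch of a client call, assuming a `/forward` route (the registration is outside this diff):

```python
import requests

# Hypothetical call through the assumed '/forward' route; the pipeline
# tokenizes the string and returns the task-specific output directly.
resp = requests.post('http://localhost:8888/forward',
                     json={'inputs': 'My name is Morgan'})
print(resp.json())  # e.g. {'model': '', 'output': ...}
```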