Commit 73fcebf7 authored by thomwolf's avatar thomwolf
Browse files

update serving command

parent 15dda5ea
...@@ -38,9 +38,9 @@ from setuptools import find_packages, setup ...@@ -38,9 +38,9 @@ from setuptools import find_packages, setup
extras = { extras = {
'serving': ['uvicorn', 'fastapi'], 'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'], 'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['uvicorn', 'fastapi', 'torch'] 'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
} }
extras['all'] = [package for package in extras.values()] extras['all'] = [package for package in extras.values()]
......
...@@ -3,9 +3,9 @@ from argparse import ArgumentParser ...@@ -3,9 +3,9 @@ from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]') parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
......
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any from typing import List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Body import logging
from logging import getLogger
try:
from pydantic import BaseModel from uvicorn import run
from uvicorn import run from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace): def serve_command_factory(args: Namespace):
""" """
...@@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.set_defaults(func=serve_command_factory) serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int): def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host self._host = host
self._port = port self._port = port
self._app = FastAPI() if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
# Register routes "Please install transformers with [serving]: pip install transformers[serving]."
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET']) "Or install FastAPI and unicorn separatly.")
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST']) else:
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST']) logger.info('Serving model over {}:{}'.format(host, port))
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST']) self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
def run(self): def run(self):
run(self._app, host=self._host, port=self._port) run(self._app, host=self._host, port=self._port)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment