Commit 73fcebf7 authored by thomwolf's avatar thomwolf
Browse files

update serving command

parent 15dda5ea
......@@ -38,9 +38,9 @@ from setuptools import find_packages, setup
extras = {
'serving': ['uvicorn', 'fastapi'],
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['uvicorn', 'fastapi', 'torch']
'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
}
extras['all'] = [package for package in extras.values()]
......
......@@ -3,9 +3,9 @@ from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__':
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
......
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Body
from logging import getLogger
from pydantic import BaseModel
from uvicorn import run
import logging
try:
from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace):
"""
......@@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host
self._port = port
self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
else:
logger.info('Serving model over {}:{}'.format(host, port))
self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
def run(self):
run(self._app, host=self._host, port=self._port)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment