Commit 3b29322d authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Expose all the pipeline argument on serve command.

parent fc624716
...@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace): ...@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
Factory function used to instantiate serving server from provided command line arguments. Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand :return: ServeCommand
""" """
nlp = pipeline(args.task, args.model) nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
return ServeCommand(nlp, args.host, args.port, args.model, args.graphql) return ServeCommand(nlp, args.host, args.port)
class ServeResult(BaseModel): class ServeModelInfoResult(BaseModel):
"""
Base class for serving result
"""
model: str
class ServeModelInfoResult(ServeResult):
""" """
Expose model information Expose model information
""" """
infos: dict infos: dict
class ServeTokenizeResult(ServeResult): class ServeTokenizeResult(BaseModel):
""" """
Tokenize result model Tokenize result model
""" """
...@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult): ...@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
tokens_ids: Optional[List[int]] tokens_ids: Optional[List[int]]
class ServeDeTokenizeResult(ServeResult): class ServeDeTokenizeResult(BaseModel):
""" """
DeTokenize result model DeTokenize result model
""" """
text: str text: str
class ServeForwardResult(ServeResult): class ServeForwardResult(BaseModel):
""" """
Forward result model Forward result model
""" """
...@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)') serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.') serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.') serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.') serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model.')
serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.') serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
serve_parser.set_defaults(func=serve_command_factory) serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool): def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving') self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline self._pipeline = pipeline
...@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
run(self._app, host=self._host, port=self._port) run(self._app, host=self._host, port=self._port)
def model_info(self): def model_info(self):
return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config)) return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)): def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
""" """
...@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
if return_ids: if return_ids:
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt) tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids) return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
else: else:
return ServeTokenizeResult(model='', tokens=tokens_txt) return ServeTokenizeResult(tokens=tokens_txt)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)}) raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
...@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
# Check we don't have empty string # Check we don't have empty string
if len(inputs) == 0: if len(inputs) == 0:
return ServeForwardResult(model='', output=[], attention=[]) return ServeForwardResult(output=[], attention=[])
try: try:
# Forward through the model # Forward through the model
output = self._pipeline(inputs) output = self._pipeline(inputs)
return ServeForwardResult( return ServeForwardResult(output=output)
model='', output=output
)
except Exception as e: except Exception as e:
raise HTTPException(500, {"error": str(e)}) raise HTTPException(500, {"error": str(e)})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment