Commit 3b29322d authored by Morgan Funtowicz

Expose all the pipeline arguments on the serve command.

parent fc624716
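
For readers skimming the change: the factory now forwards every pipeline-related flag straight into `pipeline(...)`. A minimal sketch of the equivalent call, assuming illustrative argument values that are not part of this commit:

    from transformers import pipeline

    # Mirrors serve_command_factory after this change; the task below is an
    # assumed example, not a value taken from the commit.
    nlp = pipeline(task='sentiment-analysis',
                   model=None,      # None falls back to the task's default model
                   config=None,     # defaults to the model's own config
                   tokenizer=None,  # defaults to the model's own tokenizer
                   device=-1)       # -1 = CPU, >= 0 = GPU ordinal
    print(nlp('Transformers pipelines are easy to serve.'))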
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    nlp = pipeline(args.task, args.model)
-    return ServeCommand(nlp, args.host, args.port, args.model, args.graphql)
+    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    return ServeCommand(nlp, args.host, args.port)
 
 
-class ServeResult(BaseModel):
-    """
-    Base class for serving result
-    """
-    model: str
-
-
-class ServeModelInfoResult(ServeResult):
+class ServeModelInfoResult(BaseModel):
     """
     Expose model information
     """
     infos: dict
 
 
-class ServeTokenizeResult(ServeResult):
+class ServeTokenizeResult(BaseModel):
     """
     Tokenize result model
    """
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
     tokens_ids: Optional[List[int]]
 
 
-class ServeDeTokenizeResult(ServeResult):
+class ServeDeTokenizeResult(BaseModel):
    """
     DeTokenize result model
     """
     text: str
 
 
-class ServeForwardResult(ServeResult):
+class ServeForwardResult(BaseModel):
     """
     Forward result model
     """
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
         serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
-        serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.')
-        serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
+        serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model.')
+        serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
+        serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
         serve_parser.set_defaults(func=serve_command_factory)
 
-    def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool):
+    def __init__(self, pipeline: Pipeline, host: str, port: int):
         self._logger = getLogger('transformers-cli/serving')
 
         self._pipeline = pipeline
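
With `model` and `graphql` dropped from the constructor, wiring the command up by hand shrinks accordingly. A hedged sketch, assuming the `nlp` pipeline from the example above and the parser's default host and port:

    # Assumed usage of the slimmed-down signature; run() is assumed to be the
    # CLI entry point that starts the uvicorn server shown in the next hunk.
    serve = ServeCommand(nlp, host='localhost', port=8888)
    serve.run()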
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
         run(self._app, host=self._host, port=self._port)
 
     def model_info(self):
-        return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config))
+        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
 
     def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
 
             if return_ids:
                 tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
-                return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids)
+                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
             else:
-                return ServeTokenizeResult(model='', tokens=tokens_txt)
+                return ServeTokenizeResult(tokens=tokens_txt)
 
         except Exception as e:
             raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
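
The tokenize handler is now a thin wrapper over the pipeline's tokenizer. Stripped of the FastAPI plumbing, the same flow looks roughly like this (the input string is an assumed example):

    # Equivalent of the tokenize endpoint body, minus the HTTP layer.
    tokens_txt = nlp.tokenizer.tokenize('Hello, serving world!')
    tokens_ids = nlp.tokenizer.convert_tokens_to_ids(tokens_txt)  # only when return_ids is requested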
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
 
         # Check we don't have empty string
         if len(inputs) == 0:
-            return ServeForwardResult(model='', output=[], attention=[])
+            return ServeForwardResult(output=[], attention=[])
 
         try:
             # Forward through the model
             output = self._pipeline(inputs)
-            return ServeForwardResult(
-                model='', output=output
-            )
+            return ServeForwardResult(output=output)
         except Exception as e:
             raise HTTPException(500, {"error": str(e)})