Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3b29322d
Commit
3b29322d
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Expose all the pipeline argument on serve command.
parent
fc624716
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
23 deletions
+15
-23
transformers/commands/serving.py
transformers/commands/serving.py
+15
-23
No files found.
transformers/commands/serving.py
View file @
3b29322d
...
...
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
"""
nlp
=
pipeline
(
args
.
task
,
args
.
model
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
,
args
.
model
,
args
.
graphql
)
nlp
=
pipeline
(
task
=
args
.
task
,
model
=
args
.
model
,
config
=
args
.
config
,
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
)
class
ServeResult
(
BaseModel
):
"""
Base class for serving result
"""
model
:
str
class
ServeModelInfoResult
(
ServeResult
):
class
ServeModelInfoResult
(
BaseModel
):
"""
Expose model information
"""
infos
:
dict
class
ServeTokenizeResult
(
ServeResult
):
class
ServeTokenizeResult
(
BaseModel
):
"""
Tokenize result model
"""
...
...
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
tokens_ids
:
Optional
[
List
[
int
]]
class
ServeDeTokenizeResult
(
ServeResult
):
class
ServeDeTokenizeResult
(
BaseModel
):
"""
DeTokenize result model
"""
text
:
str
class
ServeForwardResult
(
ServeResult
):
class
ServeForwardResult
(
BaseModel
):
"""
Forward result model
"""
...
...
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
serve_parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
'localhost'
,
help
=
'Interface the server will listen on.'
)
serve_parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8888
,
help
=
'Port the serving will listen to.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model to infer from.'
)
serve_parser
.
add_argument
(
'--graphql'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Enable GraphQL endpoints.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model.'
)
serve_parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Model
\'
s config name or path to stored model.'
)
serve_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Tokenizer name to use.'
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
,
model
:
str
,
graphql
:
bool
):
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
):
self
.
_logger
=
getLogger
(
'transformers-cli/serving'
)
self
.
_pipeline
=
pipeline
...
...
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
run
(
self
.
_app
,
host
=
self
.
_host
,
port
=
self
.
_port
)
def
model_info
(
self
):
return
ServeModelInfoResult
(
model
=
''
,
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
return
ServeModelInfoResult
(
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
def
tokenize
(
self
,
text_input
:
str
=
Body
(
None
,
embed
=
True
),
return_ids
:
bool
=
Body
(
False
,
embed
=
True
)):
"""
...
...
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
if
return_ids
:
tokens_ids
=
self
.
_pipeline
.
tokenizer
.
convert_tokens_to_ids
(
tokens_txt
)
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
else
:
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
)
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
''
,
"error"
:
str
(
e
)})
...
...
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
# Check we don't have empty string
if
len
(
inputs
)
==
0
:
return
ServeForwardResult
(
model
=
''
,
output
=
[],
attention
=
[])
return
ServeForwardResult
(
output
=
[],
attention
=
[])
try
:
# Forward through the model
output
=
self
.
_pipeline
(
inputs
)
return
ServeForwardResult
(
model
=
''
,
output
=
output
)
return
ServeForwardResult
(
output
=
output
)
except
Exception
as
e
:
raise
HTTPException
(
500
,
{
"error"
:
str
(
e
)})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment