Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3b29322d
Commit
3b29322d
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Expose all the pipeline arguments on the serve command.
parent
fc624716
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
23 deletions
+15
-23
transformers/commands/serving.py
transformers/commands/serving.py
+15
-23
No files found.
transformers/commands/serving.py
View file @
3b29322d
...
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
...
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
Factory function used to instantiate serving server from provided command line arguments.
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
:return: ServeCommand
"""
"""
nlp
=
pipeline
(
args
.
task
,
args
.
model
)
nlp
=
pipeline
(
task
=
args
.
task
,
model
=
args
.
model
,
config
=
args
.
config
,
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
,
args
.
model
,
args
.
graphql
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
)
class
ServeResult
(
BaseModel
):
class
ServeModelInfoResult
(
BaseModel
):
"""
Base class for serving result
"""
model
:
str
class
ServeModelInfoResult
(
ServeResult
):
"""
"""
Expose model information
Expose model information
"""
"""
infos
:
dict
infos
:
dict
class
ServeTokenizeResult
(
ServeResult
):
class
ServeTokenizeResult
(
BaseModel
):
"""
"""
Tokenize result model
Tokenize result model
"""
"""
...
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
...
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
tokens_ids
:
Optional
[
List
[
int
]]
tokens_ids
:
Optional
[
List
[
int
]]
class
ServeDeTokenizeResult
(
ServeResult
):
class
ServeDeTokenizeResult
(
BaseModel
):
"""
"""
DeTokenize result model
DeTokenize result model
"""
"""
text
:
str
text
:
str
class
ServeForwardResult
(
ServeResult
):
class
ServeForwardResult
(
BaseModel
):
"""
"""
Forward result model
Forward result model
"""
"""
...
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
serve_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
serve_parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
'localhost'
,
help
=
'Interface the server will listen on.'
)
serve_parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
'localhost'
,
help
=
'Interface the server will listen on.'
)
serve_parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8888
,
help
=
'Port the serving will listen to.'
)
serve_parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8888
,
help
=
'Port the serving will listen to.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model to infer from.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model.'
)
serve_parser
.
add_argument
(
'--graphql'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Enable GraphQL endpoints.'
)
serve_parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Model
\'
s config name or path to stored model.'
)
serve_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Tokenizer name to use.'
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
,
model
:
str
,
graphql
:
bool
):
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
):
self
.
_logger
=
getLogger
(
'transformers-cli/serving'
)
self
.
_logger
=
getLogger
(
'transformers-cli/serving'
)
self
.
_pipeline
=
pipeline
self
.
_pipeline
=
pipeline
...
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
run
(
self
.
_app
,
host
=
self
.
_host
,
port
=
self
.
_port
)
run
(
self
.
_app
,
host
=
self
.
_host
,
port
=
self
.
_port
)
def
model_info
(
self
):
def
model_info
(
self
):
return
ServeModelInfoResult
(
model
=
''
,
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
return
ServeModelInfoResult
(
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
def
tokenize
(
self
,
text_input
:
str
=
Body
(
None
,
embed
=
True
),
return_ids
:
bool
=
Body
(
False
,
embed
=
True
)):
def
tokenize
(
self
,
text_input
:
str
=
Body
(
None
,
embed
=
True
),
return_ids
:
bool
=
Body
(
False
,
embed
=
True
)):
"""
"""
...
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
if
return_ids
:
if
return_ids
:
tokens_ids
=
self
.
_pipeline
.
tokenizer
.
convert_tokens_to_ids
(
tokens_txt
)
tokens_ids
=
self
.
_pipeline
.
tokenizer
.
convert_tokens_to_ids
(
tokens_txt
)
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
else
:
else
:
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
''
,
"error"
:
str
(
e
)})
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
''
,
"error"
:
str
(
e
)})
...
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
# Check we don't have empty string
# Check we don't have empty string
if
len
(
inputs
)
==
0
:
if
len
(
inputs
)
==
0
:
return
ServeForwardResult
(
model
=
''
,
output
=
[],
attention
=
[])
return
ServeForwardResult
(
output
=
[],
attention
=
[])
try
:
try
:
# Forward through the model
# Forward through the model
output
=
self
.
_pipeline
(
inputs
)
output
=
self
.
_pipeline
(
inputs
)
return
ServeForwardResult
(
return
ServeForwardResult
(
output
=
output
)
model
=
''
,
output
=
output
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
500
,
{
"error"
:
str
(
e
)})
raise
HTTPException
(
500
,
{
"error"
:
str
(
e
)})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment