Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3b29322d
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "801894e08c53a80c2b4377ce102a44dccfda8928"
Commit
3b29322d
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Expose all the pipeline argument on serve command.
parent
fc624716
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
23 deletions
+15
-23
transformers/commands/serving.py
transformers/commands/serving.py
+15
-23
No files found.
transformers/commands/serving.py
View file @
3b29322d
...
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
...
@@ -17,25 +17,18 @@ def serve_command_factory(args: Namespace):
Factory function used to instantiate serving server from provided command line arguments.
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
:return: ServeCommand
"""
"""
nlp
=
pipeline
(
args
.
task
,
args
.
model
)
nlp
=
pipeline
(
task
=
args
.
task
,
model
=
args
.
model
,
config
=
args
.
config
,
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
,
args
.
model
,
args
.
graphql
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
)
class
ServeResult
(
BaseModel
):
class
ServeModelInfoResult
(
BaseModel
):
"""
Base class for serving result
"""
model
:
str
class
ServeModelInfoResult
(
ServeResult
):
"""
"""
Expose model information
Expose model information
"""
"""
infos
:
dict
infos
:
dict
class
ServeTokenizeResult
(
ServeResult
):
class
ServeTokenizeResult
(
BaseModel
):
"""
"""
Tokenize result model
Tokenize result model
"""
"""
...
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
...
@@ -43,14 +36,14 @@ class ServeTokenizeResult(ServeResult):
tokens_ids
:
Optional
[
List
[
int
]]
tokens_ids
:
Optional
[
List
[
int
]]
class
ServeDeTokenizeResult
(
ServeResult
):
class
ServeDeTokenizeResult
(
BaseModel
):
"""
"""
DeTokenize result model
DeTokenize result model
"""
"""
text
:
str
text
:
str
class
ServeForwardResult
(
ServeResult
):
class
ServeForwardResult
(
BaseModel
):
"""
"""
Forward result model
Forward result model
"""
"""
...
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -71,11 +64,12 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
serve_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
serve_parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
'localhost'
,
help
=
'Interface the server will listen on.'
)
serve_parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
'localhost'
,
help
=
'Interface the server will listen on.'
)
serve_parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8888
,
help
=
'Port the serving will listen to.'
)
serve_parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8888
,
help
=
'Port the serving will listen to.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model to infer from.'
)
serve_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Model
\'
s name or path to stored model.'
)
serve_parser
.
add_argument
(
'--graphql'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Enable GraphQL endpoints.'
)
serve_parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Model
\'
s config name or path to stored model.'
)
serve_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Tokenizer name to use.'
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
,
model
:
str
,
graphql
:
bool
):
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
):
self
.
_logger
=
getLogger
(
'transformers-cli/serving'
)
self
.
_logger
=
getLogger
(
'transformers-cli/serving'
)
self
.
_pipeline
=
pipeline
self
.
_pipeline
=
pipeline
...
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -95,7 +89,7 @@ class ServeCommand(BaseTransformersCLICommand):
run
(
self
.
_app
,
host
=
self
.
_host
,
port
=
self
.
_port
)
run
(
self
.
_app
,
host
=
self
.
_host
,
port
=
self
.
_port
)
def
model_info
(
self
):
def
model_info
(
self
):
return
ServeModelInfoResult
(
model
=
''
,
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
return
ServeModelInfoResult
(
infos
=
vars
(
self
.
_pipeline
.
model
.
config
))
def
tokenize
(
self
,
text_input
:
str
=
Body
(
None
,
embed
=
True
),
return_ids
:
bool
=
Body
(
False
,
embed
=
True
)):
def
tokenize
(
self
,
text_input
:
str
=
Body
(
None
,
embed
=
True
),
return_ids
:
bool
=
Body
(
False
,
embed
=
True
)):
"""
"""
...
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -108,9 +102,9 @@ class ServeCommand(BaseTransformersCLICommand):
if
return_ids
:
if
return_ids
:
tokens_ids
=
self
.
_pipeline
.
tokenizer
.
convert_tokens_to_ids
(
tokens_txt
)
tokens_ids
=
self
.
_pipeline
.
tokenizer
.
convert_tokens_to_ids
(
tokens_txt
)
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
,
tokens_ids
=
tokens_ids
)
else
:
else
:
return
ServeTokenizeResult
(
model
=
''
,
tokens
=
tokens_txt
)
return
ServeTokenizeResult
(
tokens
=
tokens_txt
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
''
,
"error"
:
str
(
e
)})
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
''
,
"error"
:
str
(
e
)})
...
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
...
@@ -139,13 +133,11 @@ class ServeCommand(BaseTransformersCLICommand):
# Check we don't have empty string
# Check we don't have empty string
if
len
(
inputs
)
==
0
:
if
len
(
inputs
)
==
0
:
return
ServeForwardResult
(
model
=
''
,
output
=
[],
attention
=
[])
return
ServeForwardResult
(
output
=
[],
attention
=
[])
try
:
try
:
# Forward through the model
# Forward through the model
output
=
self
.
_pipeline
(
inputs
)
output
=
self
.
_pipeline
(
inputs
)
return
ServeForwardResult
(
return
ServeForwardResult
(
output
=
output
)
model
=
''
,
output
=
output
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
500
,
{
"error"
:
str
(
e
)})
raise
HTTPException
(
500
,
{
"error"
:
str
(
e
)})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment