Commit a096e2a8
Authored Dec 16, 2019 by Morgan Funtowicz
Parent 43a4e1bb

WIP serving through HTTP internally using pipelines.

Showing 1 changed file with 21 additions and 46 deletions:

transformers/commands/serving.py (+21, -46)
 from argparse import ArgumentParser, Namespace
 from typing import List, Optional, Union, Any

-import torch
 from fastapi import FastAPI, HTTPException, Body
 from logging import getLogger
 from pydantic import BaseModel
 from uvicorn import run

-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
+from transformers.pipelines import SUPPORTED_TASKS, pipeline


 def serve_command_factory(args: Namespace):
@@ -17,7 +17,8 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    return ServeCommand(args.host, args.port, args.model, args.graphql)
+    nlp = pipeline(args.task, args.model)
+    return ServeCommand(nlp, args.host, args.port, args.model, args.graphql)


 class ServeResult(BaseModel):
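As a point of reference, a minimal sketch of what the new wiring amounts to when driven programmatically instead of through argparse. The task and model values are hypothetical, and the run() entry point is assumed from the BaseTransformersCLICommand contract rather than shown in this diff:

    # Hypothetical task/model values; only the wiring mirrors the diff.
    from transformers import pipeline
    from transformers.commands.serving import ServeCommand

    nlp = pipeline('feature-extraction', 'bert-base-uncased')  # allocated once, up front
    server = ServeCommand(nlp, 'localhost', 8888, 'bert-base-uncased', False)
    server.run()  # assumed entry point: serves the FastAPI app via uvicorn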
@@ -53,8 +54,6 @@ class ServeForwardResult(ServeResult):
     """
     Forward result model
     """
-    tokens: List[str]
-    tokens_ids: List[int]
     output: Any
@@ -68,19 +67,18 @@ class ServeCommand(BaseTransformersCLICommand):
         :return:
         """
         serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
+        serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
+        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
         serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.')
         serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, host: str, port: int, model: str, graphql: bool):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool):
         self._logger = getLogger('transformers-cli/serving')
-        self._logger.info('Loading model {}'.format(model))
-        self._model_name = model
-        self._model = AutoModel.from_pretrained(model)
-        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._pipeline = pipeline

         self._logger.info('Serving model over {}:{}'.format(host, port))
         self._host = host
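The new --task choices come straight from SUPPORTED_TASKS, the registry that transformers.pipelines uses to map task names to pipeline implementations. A quick way to list what can be served (the exact set depends on the installed transformers version):

    from transformers.pipelines import SUPPORTED_TASKS

    print(sorted(SUPPORTED_TASKS.keys()))
    # e.g. ['feature-extraction', 'ner', 'question-answering', 'sentiment-analysis']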
@@ -97,7 +95,7 @@ class ServeCommand(BaseTransformersCLICommand):
         run(self._app, host=self._host, port=self._port)

     def model_info(self):
-        return ServeModelInfoResult(model=self._model_name, infos=vars(self._model.config))
+        return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config))

     def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
         """
@@ -106,16 +104,16 @@ class ServeCommand(BaseTransformersCLICommand):
         - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
         """
         try:
-            tokens_txt = self._tokenizer.tokenize(text_input)
+            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

             if return_ids:
-                tokens_ids = self._tokenizer.convert_tokens_to_ids(tokens_txt)
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt, tokens_ids=tokens_ids)
+                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids)
             else:
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt)

         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

     def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
                    skip_special_tokens: bool = Body(False, embed=True),
@@ -127,14 +125,12 @@ class ServeCommand(BaseTransformersCLICommand):
         - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
         """
         try:
-            decoded_str = self._tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
-            return ServeDeTokenizeResult(model=self._model_name, text=decoded_str)
+            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
+            return ServeDeTokenizeResult(model='', text=decoded_str)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

-    def forward(self, inputs: Union[str, List[str], List[int]] = Body(None, embed=True),
-                attention_mask: Optional[List[int]] = Body(None, embed=True),
-                tokens_type_ids: Optional[List[int]] = Body(None, embed=True)):
+    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
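The widened inputs annotation mirrors what pipelines accept directly now that preprocessing lives inside the pipeline rather than the handler. A hedged sketch, with task names illustrative and default model resolution depending on the transformers version:

    from transformers import pipeline

    nlp = pipeline('sentiment-analysis')
    nlp('This is great!')             # str
    nlp(['This is great!', 'Meh.'])   # List[str]
    # Question-answering pipelines consume dict-shaped inputs such as
    # {'question': ..., 'context': ...}, hence the dict / List[dict] variants.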
@@ -143,34 +139,13 @@ class ServeCommand(BaseTransformersCLICommand):
         # Check we don't have empty string
         if len(inputs) == 0:
-            return ServeForwardResult(model=self._model_name, output=[], attention=[])
+            return ServeForwardResult(model='', output=[], attention=[])

-        if isinstance(inputs, str):
-            inputs_tokens = self._tokenizer.tokenize(inputs)
-            inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-        elif isinstance(inputs, List):
-            if isinstance(inputs[0], str):
-                inputs_tokens = inputs
-                inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-            elif isinstance(inputs[0], int):
-                inputs_tokens = []
-                inputs_ids = inputs
-            else:
-                error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs[0]))
-                raise HTTPException(423, detail={"error": error_msg})
-        else:
-            error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs))
-            raise HTTPException(423, detail={"error": error_msg})
-
         try:
             # Forward through the model
-            t_input_ids = torch.tensor(inputs_ids).unsqueeze(0)
-            output = self._model(t_input_ids, attention_mask, tokens_type_ids)
+            output = self._pipeline(inputs)
             return ServeForwardResult(
-                model=self._model_name, tokens=inputs_tokens,
-                tokens_ids=inputs_ids, output=output[0].tolist()
+                model='', output=output
             )
         except Exception as e:
             raise HTTPException(500, {"error": str(e)})
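With the manual tokenize/convert/tensor steps gone, forward reduces to a single pipeline call: the pipeline tokenizes, runs the model, and post-processes, returning plain Python structures that serialize to JSON without the old output[0].tolist() step; this is also why import torch disappears at the top of the file. A sketch, assuming a feature-extraction pipeline and an illustrative model name:

    from transformers import pipeline

    nlp = pipeline('feature-extraction', 'bert-base-uncased')
    output = nlp('Hello world')  # nested lists, roughly [1, seq_len, hidden_size]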