Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
908cd5ea
Commit
908cd5ea
authored
Jan 10, 2020
by
Morgan Funtowicz
Committed by
Lysandre Debut
Jan 20, 2020
Browse files
Make forward asynchronous to avoid long computations timing out.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent
6e6c8c52
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
19 deletions
+52
-19
src/transformers/commands/serving.py
src/transformers/commands/serving.py
+45
-15
src/transformers/file_utils.py
src/transformers/file_utils.py
+7
-4
No files found.
src/transformers/commands/serving.py
View file @
908cd5ea
import
logging
from
argparse
import
ArgumentParser
,
Namespace
from
typing
import
Any
,
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
from
starlette.responses
import
JSONResponse
from
transformers
import
Pipeline
from
transformers.commands
import
BaseTransformersCLICommand
...
...
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
try
:
from
uvicorn
import
run
from
fastapi
import
FastAPI
,
HTTPException
,
Body
from
fastapi.routing
import
APIRoute
from
pydantic
import
BaseModel
_serve_dependancies_installed
=
True
...
...
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
,
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
)
return
ServeCommand
(
nlp
,
args
.
host
,
args
.
port
,
args
.
workers
)
class
ServeModelInfoResult
(
BaseModel
):
...
...
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
)
serve_parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
"localhost"
,
help
=
"Interface the server will listen on."
)
serve_parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8888
,
help
=
"Port the serving will listen to."
)
serve_parser
.
add_argument
(
"--workers"
,
type
=
int
,
default
=
1
,
help
=
"Number of http workers"
)
serve_parser
.
add_argument
(
"--model"
,
type
=
str
,
help
=
"Model's name or path to stored model."
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
help
=
"Model's config name or path to stored model."
)
serve_parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
help
=
"Tokenizer name to use."
)
...
...
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
)
serve_parser
.
set_defaults
(
func
=
serve_command_factory
)
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
):
def
__init__
(
self
,
pipeline
:
Pipeline
,
host
:
str
,
port
:
int
,
workers
:
int
):
self
.
_pipeline
=
pipeline
self
.
_host
=
host
self
.
_port
=
port
self
.
host
=
host
self
.
port
=
port
self
.
workers
=
workers
if
not
_serve_dependancies_installed
:
raise
RuntimeError
(
"Using serve command requires FastAPI and unicorn. "
...
...
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
)
else
:
logger
.
info
(
"Serving model over {}:{}"
.
format
(
host
,
port
))
self
.
_app
=
FastAPI
()
# Register routes
self
.
_app
.
add_api_route
(
"/"
,
self
.
model_info
,
response_model
=
ServeModelInfoResult
,
methods
=
[
"GET"
])
self
.
_app
.
add_api_route
(
"/tokenize"
,
self
.
tokenize
,
response_model
=
ServeTokenizeResult
,
methods
=
[
"POST"
])
self
.
_app
.
add_api_route
(
"/detokenize"
,
self
.
detokenize
,
response_model
=
ServeDeTokenizeResult
,
methods
=
[
"POST"
]
self
.
_app
=
FastAPI
(
routes
=
[
APIRoute
(
"/"
,
self
.
model_info
,
response_model
=
ServeModelInfoResult
,
response_class
=
JSONResponse
,
methods
=
[
"GET"
],
),
APIRoute
(
"/tokenize"
,
self
.
tokenize
,
response_model
=
ServeTokenizeResult
,
response_class
=
JSONResponse
,
methods
=
[
"POST"
],
),
APIRoute
(
"/detokenize"
,
self
.
detokenize
,
response_model
=
ServeDeTokenizeResult
,
response_class
=
JSONResponse
,
methods
=
[
"POST"
],
),
APIRoute
(
"/forward"
,
self
.
forward
,
response_model
=
ServeForwardResult
,
response_class
=
JSONResponse
,
methods
=
[
"POST"
],
),
],
timeout
=
600
,
)
self
.
_app
.
add_api_route
(
"/forward"
,
self
.
forward
,
response_model
=
ServeForwardResult
,
methods
=
[
"POST"
])
def run(self):
    """Start serving the model over HTTP.

    Binds the FastAPI application to the configured interface/port and
    spawns the requested number of uvicorn worker processes so that a
    long-running `forward` call in one worker does not block the others.

    Note: the bare `run(...)` called here is uvicorn's `run`, imported at
    module level — it is not a recursive call to this method.
    """
    # Post-commit attribute names are the public `host`/`port`/`workers`
    # set in __init__ (the old private `_host`/`_port` were removed).
    run(self._app, host=self.host, port=self.port, workers=self.workers)
def model_info(self):
    """Return the served model's configuration wrapped in a ServeModelInfoResult."""
    # Expose every attribute of the underlying model config as a plain dict.
    config_attributes = vars(self._pipeline.model.config)
    return ServeModelInfoResult(infos=config_attributes)
...
...
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
500
,
detail
=
{
"model"
:
""
,
"error"
:
str
(
e
)})
def
forward
(
self
,
inputs
:
Union
[
str
,
dict
,
List
[
str
],
List
[
int
],
List
[
dict
]]
=
Body
(
None
,
embed
=
True
)):
async
def
forward
(
self
,
inputs
=
Body
(
None
,
embed
=
True
)):
"""
**inputs**:
**attention_mask**:
...
...
src/transformers/file_utils.py
View file @
908cd5ea
...
...
@@ -28,8 +28,9 @@ from . import __version__
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
try
:
if
os
.
environ
.
get
(
"USE_TORCH"
,
'AUTO'
).
upper
()
in
(
"1"
,
"ON"
,
"YES"
,
"AUTO"
)
and
\
os
.
environ
.
get
(
"USE_TF"
,
'AUTO'
).
upper
()
not
in
(
"1"
,
"ON"
,
"YES"
):
USE_TF
=
os
.
environ
.
get
(
"USE_TF"
,
"AUTO"
).
upper
()
USE_TORCH
=
os
.
environ
.
get
(
"USE_TORCH"
,
"AUTO"
).
upper
()
if
USE_TORCH
in
(
"1"
,
"ON"
,
"YES"
,
"AUTO"
)
and
USE_TF
not
in
(
"1"
,
"ON"
,
"YES"
):
import
torch
_torch_available
=
True
# pylint: disable=invalid-name
...
...
@@ -41,8 +42,10 @@ except ImportError:
_torch_available
=
False
# pylint: disable=invalid-name
try
:
if
os
.
environ
.
get
(
"USE_TF"
,
'AUTO'
).
upper
()
in
(
"1"
,
"ON"
,
"YES"
,
"AUTO"
)
and
\
os
.
environ
.
get
(
"USE_TORCH"
,
'AUTO'
).
upper
()
not
in
(
"1"
,
"ON"
,
"YES"
):
USE_TF
=
os
.
environ
.
get
(
"USE_TF"
,
"AUTO"
).
upper
()
USE_TORCH
=
os
.
environ
.
get
(
"USE_TORCH"
,
"AUTO"
).
upper
()
if
USE_TF
in
(
"1"
,
"ON"
,
"YES"
,
"AUTO"
)
and
USE_TORCH
not
in
(
"1"
,
"ON"
,
"YES"
):
import
tensorflow
as
tf
assert
hasattr
(
tf
,
"__version__"
)
and
int
(
tf
.
__version__
[
0
])
>=
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment