Commit 908cd5ea authored Jan 10, 2020 by Morgan Funtowicz; committed by Lysandre Debut, Jan 20, 2020
Make forward asynchronous to avoid long computations timing out.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 6e6c8c52
Showing 2 changed files with 52 additions and 19 deletions (+52 -19)
src/transformers/commands/serving.py (+45 -15)
src/transformers/file_utils.py (+7 -4)
src/transformers/commands/serving.py
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
+
+from starlette.responses import JSONResponse

 from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
 try:
     from uvicorn import run
     from fastapi import FastAPI, HTTPException, Body
+    from fastapi.routing import APIRoute
     from pydantic import BaseModel

     _serve_dependancies_installed = True
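For context, serving.py guards its web dependencies with a try/except flag (the misspelled _serve_dependancies_installed name is as it appears in the source). A condensed sketch of that optional-import pattern; the except branch is inferred from the later `if not _serve_dependancies_installed` check:

# Optional-dependency guard, mirroring serving.py's pattern.
try:
    from uvicorn import run  # noqa: F401
    from fastapi import FastAPI, HTTPException, Body  # noqa: F401
    from fastapi.routing import APIRoute  # noqa: F401
    from pydantic import BaseModel  # noqa: F401

    _serve_dependancies_installed = True
except ImportError:
    # Assumed fallback: the flag lets the CLI raise a friendly error later.
    _serve_dependancies_installed = False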
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
         tokenizer=args.tokenizer,
         device=args.device,
     )
-    return ServeCommand(nlp, args.host, args.port)
+    return ServeCommand(nlp, args.host, args.port, args.workers)


 class ServeModelInfoResult(BaseModel):
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
         serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
+        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
         serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
         serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
         serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
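The new --workers flag is wired through argparse like the existing server options. A minimal standalone sketch of the parsing behavior, with names and defaults taken from the hunk above:

from argparse import ArgumentParser

# Sketch of the serve arguments touched by this commit.
parser = ArgumentParser("serve")
parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
parser.add_argument("--workers", type=int, default=1, help="Number of http workers")

args = parser.parse_args(["--workers", "4"])
print(args.host, args.port, args.workers)  # localhost 8888 4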
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
         )
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, pipeline: Pipeline, host: str, port: int):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
         self._pipeline = pipeline
-        self._host = host
-        self._port = port
+        self.host = host
+        self.port = port
+        self.workers = workers

         if not _serve_dependancies_installed:
             raise RuntimeError(
                 "Using serve command requires FastAPI and unicorn. "
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
             )
         else:
             logger.info("Serving model over {}:{}".format(host, port))
-            self._app = FastAPI()
-
-            # Register routes
-            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
-            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
-            self._app.add_api_route(
-                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
-            )
-            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
+            self._app = FastAPI(
+                routes=[
+                    APIRoute(
+                        "/",
+                        self.model_info,
+                        response_model=ServeModelInfoResult,
+                        response_class=JSONResponse,
+                        methods=["GET"],
+                    ),
+                    APIRoute(
+                        "/tokenize",
+                        self.tokenize,
+                        response_model=ServeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/detokenize",
+                        self.detokenize,
+                        response_model=ServeDeTokenizeResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                    APIRoute(
+                        "/forward",
+                        self.forward,
+                        response_model=ServeForwardResult,
+                        response_class=JSONResponse,
+                        methods=["POST"],
+                    ),
+                ],
+                timeout=600,
+            )

     def run(self):
-        run(self._app, host=self._host, port=self._port)
+        run(self._app, host=self.host, port=self.port, workers=self.workers)

     def model_info(self):
         return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
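The hunk above switches from imperative add_api_route calls to declaring the routes up front in the FastAPI constructor. A minimal, self-contained sketch of that pattern; the endpoint here is a placeholder standing in for ServeCommand.model_info, not the real handler:

from fastapi import FastAPI
from fastapi.routing import APIRoute
from starlette.responses import JSONResponse

def model_info():
    # Placeholder endpoint; the real one returns ServeModelInfoResult.
    return {"status": "ok"}

# Routes are handed to the constructor instead of registered one by one.
app = FastAPI(
    routes=[
        APIRoute("/", model_info, response_class=JSONResponse, methods=["GET"]),
    ]
)

# uvicorn.run(app, host="localhost", port=8888, workers=1) would serve it,
# mirroring the updated ServeCommand.run above.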
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
         except Exception as e:
             raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

-    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
+    async def forward(self, inputs=Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
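In FastAPI, a plain def endpoint is dispatched to a worker threadpool while an async def endpoint runs on the event loop; together with the larger timeout and worker count added in this commit, that is presumably what lets long forward passes finish without the request timing out. Since Body(None, embed=True) expects the payload wrapped in a JSON object keyed by the parameter name, a client call might look like this (host and port assumed from the serve defaults above, input text is an arbitrary example):

import requests

# Hypothetical client call against the async /forward endpoint.
response = requests.post(
    "http://localhost:8888/forward",
    json={"inputs": "Hello world"},  # embed=True expects {"inputs": ...}
)
print(response.json())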
src/transformers/file_utils.py
@@ -28,8 +28,9 @@ from . import __version__
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 try:
-    if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
         import torch

         _torch_available = True  # pylint: disable=invalid-name
@@ -41,8 +42,10 @@ except ImportError:
     _torch_available = False  # pylint: disable=invalid-name

 try:
-    if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
-            os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
+    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
         import tensorflow as tf

         assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
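These two hunks only hoist the environment lookups into named variables; the behavior is unchanged: USE_TF and USE_TORCH select which framework transformers imports at startup. A short usage sketch of that gating, inferred from the conditions above (the variables must be set before transformers is first imported):

import os

# Force the TensorFlow branch and skip the PyTorch import.
os.environ["USE_TF"] = "1"
os.environ["USE_TORCH"] = "OFF"  # any value outside ("1", "ON", "YES") leaves torch unforced

import transformers  # noqa: E402 -- importing after env setup is intentional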