Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c17e7cde
Unverified
Commit
c17e7cde
authored
Dec 13, 2021
by
Suzen Fylke
Committed by
GitHub
Dec 13, 2021
Browse files
Add ability to get a list of supported pipeline tasks (#14732)
parent
3d66146a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
10 deletions
+15
-10
src/transformers/commands/run.py
src/transformers/commands/run.py
+2
-4
src/transformers/commands/serving.py
src/transformers/commands/serving.py
+2
-2
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+11
-4
No files found.
src/transformers/commands/run.py
View file @
c17e7cde
...
...
@@ -14,7 +14,7 @@
from
argparse
import
ArgumentParser
from
..pipelines
import
SUPPORTED_TASKS
,
TASK_ALIASES
,
Pipeline
,
PipelineDataFormat
,
pipeline
from
..pipelines
import
Pipeline
,
PipelineDataFormat
,
get_supported_tasks
,
pipeline
from
..utils
import
logging
from
.
import
BaseTransformersCLICommand
...
...
@@ -63,9 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
@
staticmethod
def
register_subcommand
(
parser
:
ArgumentParser
):
run_parser
=
parser
.
add_parser
(
"run"
,
help
=
"Run a pipeline through the CLI"
)
run_parser
.
add_argument
(
"--task"
,
choices
=
list
(
SUPPORTED_TASKS
.
keys
())
+
list
(
TASK_ALIASES
.
keys
()),
help
=
"Task to run"
)
run_parser
.
add_argument
(
"--task"
,
choices
=
get_supported_tasks
(),
help
=
"Task to run"
)
run_parser
.
add_argument
(
"--input"
,
type
=
str
,
help
=
"Path to the file to use for inference"
)
run_parser
.
add_argument
(
"--output"
,
type
=
str
,
help
=
"Path to the file that will be used post to write results."
)
run_parser
.
add_argument
(
"--model"
,
type
=
str
,
help
=
"Name or path to the model to instantiate."
)
...
...
src/transformers/commands/serving.py
View file @
c17e7cde
...
...
@@ -15,7 +15,7 @@
from
argparse
import
ArgumentParser
,
Namespace
from
typing
import
Any
,
List
,
Optional
from
..pipelines
import
SUPPORTED_TASKS
,
TASK_ALIASES
,
Pipeline
,
pipeline
from
..pipelines
import
Pipeline
,
get_supported_tasks
,
pipeline
from
..utils
import
logging
from
.
import
BaseTransformersCLICommand
...
...
@@ -104,7 +104,7 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser
.
add_argument
(
"--task"
,
type
=
str
,
choices
=
list
(
SUPPORTED_TASKS
.
keys
())
+
list
(
TASK_ALIASES
.
key
s
()
)
,
choices
=
get_supported_task
s
(),
help
=
"The task to run the pipeline on"
,
)
serve_parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
"localhost"
,
help
=
"Interface the server will listen on."
)
...
...
src/transformers/pipelines/__init__.py
View file @
c17e7cde
...
...
@@ -20,7 +20,7 @@ import json
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
..configuration_utils
import
PretrainedConfig
from
..feature_extraction_utils
import
PreTrainedFeatureExtractor
...
...
@@ -252,6 +252,15 @@ SUPPORTED_TASKS = {
}
def
get_supported_tasks
()
->
List
[
str
]:
"""
Returns a list of supported task strings.
"""
supported_tasks
=
list
(
SUPPORTED_TASKS
.
keys
())
+
list
(
TASK_ALIASES
.
keys
())
supported_tasks
.
sort
()
return
supported_tasks
def
get_task
(
model
:
str
,
use_auth_token
:
Optional
[
str
]
=
None
)
->
str
:
tmp
=
io
.
BytesIO
()
headers
=
{}
...
...
@@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
return
targeted_task
,
(
tokens
[
1
],
tokens
[
3
])
raise
KeyError
(
f
"Invalid translation task
{
task
}
, use 'translation_XX_to_YY' format"
)
raise
KeyError
(
f
"Unknown task
{
task
}
, available tasks are
{
list
(
SUPPORTED_TASKS
.
keys
())
+
[
'translation_XX_to_YY'
]
}
"
)
raise
KeyError
(
f
"Unknown task
{
task
}
, available tasks are
{
get_supported_tasks
()
+
[
'translation_XX_to_YY'
]
}
"
)
def
pipeline
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment