Unverified Commit c17e7cde authored by Suzen Fylke's avatar Suzen Fylke Committed by GitHub
Browse files

Add ability to get a list of supported pipeline tasks (#14732)

parent 3d66146a
......@@ -14,7 +14,7 @@
from argparse import ArgumentParser
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
......@@ -63,9 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument(
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
......
......@@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
......@@ -104,7 +104,7 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.add_argument(
"--task",
type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
choices=get_supported_tasks(),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
......
......@@ -20,7 +20,7 @@ import json
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
......@@ -252,6 +252,15 @@ SUPPORTED_TASKS = {
}
def get_supported_tasks() -> List[str]:
"""
Returns a list of supported task strings.
"""
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
supported_tasks.sort()
return supported_tasks
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
tmp = io.BytesIO()
headers = {}
......@@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
return targeted_task, (tokens[1], tokens[3])
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
raise KeyError(
f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
)
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
def pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment