"examples/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "85b56d01c4bd9b19a6b791918722878288d00da3"
Unverified Commit c17e7cde authored by Suzen Fylke's avatar Suzen Fylke Committed by GitHub
Browse files

Add ability to get a list of supported pipeline tasks (#14732)

parent 3d66146a
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -63,9 +63,7 @@ class RunCommand(BaseTransformersCLICommand): ...@@ -63,9 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
@staticmethod @staticmethod
def register_subcommand(parser: ArgumentParser): def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI") run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
run_parser.add_argument( run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
)
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference") run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.") run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.") run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional from typing import Any, List, Optional
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -104,7 +104,7 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -104,7 +104,7 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.add_argument( serve_parser.add_argument(
"--task", "--task",
type=str, type=str,
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), choices=get_supported_tasks(),
help="The task to run the pipeline on", help="The task to run the pipeline on",
) )
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
......
...@@ -20,7 +20,7 @@ import json ...@@ -20,7 +20,7 @@ import json
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..feature_extraction_utils import PreTrainedFeatureExtractor
...@@ -252,6 +252,15 @@ SUPPORTED_TASKS = { ...@@ -252,6 +252,15 @@ SUPPORTED_TASKS = {
} }
def get_supported_tasks() -> List[str]:
"""
Returns a list of supported task strings.
"""
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
supported_tasks.sort()
return supported_tasks
def get_task(model: str, use_auth_token: Optional[str] = None) -> str: def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
tmp = io.BytesIO() tmp = io.BytesIO()
headers = {} headers = {}
...@@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: ...@@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
return targeted_task, (tokens[1], tokens[3]) return targeted_task, (tokens[1], tokens[3])
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format") raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
raise KeyError( raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
)
def pipeline( def pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment