Unverified Commit 7b685f52 authored by Funtowicz Morgan, committed by GitHub

Increase pipeline support for ONNX export. (#5005)

* Increase pipeline support for ONNX export.

* Style.
parent 1affde2f
@@ -8,6 +8,19 @@ from transformers.pipelines import Pipeline, pipeline
 from transformers.tokenization_utils import BatchEncoding
 
 
+SUPPORTED_PIPELINES = [
+    "feature-extraction",
+    "ner",
+    "sentiment-analysis",
+    "fill-mask",
+    "question-answering",
+    "text-generation",
+    "translation_en_to_fr",
+    "translation_en_to_de",
+    "translation_en_to_ro",
+]
+
+
 class OnnxConverterArgumentParser(ArgumentParser):
     """
     Wraps all the script arguments supported to export transformers models to ONNX IR
@@ -16,6 +29,7 @@ class OnnxConverterArgumentParser(ArgumentParser):
 
     def __init__(self):
         super(OnnxConverterArgumentParser, self).__init__("ONNX Converter")
+        self.add_argument("--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction")
         self.add_argument("--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)")
         self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
         self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
@@ -108,7 +122,7 @@ def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] =
     print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
 
     # Allocate tokenizer and model
-    return pipeline("feature-extraction", model=model, tokenizer=tokenizer, framework=framework)
+    return pipeline(args.pipeline, model=model, tokenizer=tokenizer, framework=framework)
 
 
 def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
...
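Note: after this change the converter can target any task in SUPPORTED_PIPELINES through the new --pipeline flag, which defaults to "feature-extraction" so existing invocations keep their previous behavior. A minimal usage sketch follows (the script path, the positional output argument, and the model name are assumptions based on the surrounding file, not part of this diff):

    python src/transformers/convert_graph_to_onnx.py --pipeline question-answering --framework pt --model bert-base-cased bert-base-cased.onnx

This would export a bert-base-cased checkpoint wrapped in the question-answering pipeline, rather than the previously hard-coded feature-extraction pipeline.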