"examples/summarization/test_summarization_examples.py" did not exist on "0373b60c4cf95efc0f24b868448be8e393ecac20"
Unverified commit 0e1fce3c authored by Anthony MOI, committed by GitHub
Browse files

Fix convert_graph_to_onnx (#5230)

parent 5543efd5
......@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
    """
    Allocate and return the pipeline to convert to ONNX.

    Args:
        pipeline_name: Name of the pipeline task to instantiate
            (e.g. "feature-extraction"); forwarded to `pipeline()`.
        framework: Target framework, "pt" (PyTorch) or "tf" (TensorFlow).
        model: Model identifier or path to load.
        tokenizer: Optional tokenizer identifier; defaults to `model` when omitted.

    Returns:
        The instantiated Pipeline wrapping the tokenizer and model.

    Raises:
        Exception: If the requested framework is not installed.
    """
    # If no tokenizer provided, fall back to the model identifier
    if tokenizer is None:
        tokenizer = model

    # Check the wanted framework is available before attempting to load anything
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))

    # Allocate tokenizer and model via the task name (NOT the module-level `args`,
    # which is undefined when this function is called as a library entry point)
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
......@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
if not is_tf_available():
raise Exception(
"Cannot convert {} because TF is not installed. Please install torch first.".format(args.model)
)
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
......@@ -187,11 +191,12 @@ def convert(
opset: int,
tokenizer: Optional[str] = None,
use_external_format: bool = False,
pipeline_name: str = "feature-extraction",
):
print("ONNX opset version set to: {}".format(opset))
# Load the pipeline
nlp = load_graph_from_args(framework, model, tokenizer)
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
parent = dirname(output)
if not exists(parent):
......@@ -229,7 +234,15 @@ if __name__ == "__main__":
try:
# Convert
convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
convert(
args.framework,
args.model,
args.output,
args.opset,
args.tokenizer,
args.use_external_format,
args.pipeline,
)
# And verify
if args.check_loading:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment