Unverified Commit 0e1fce3c authored by Anthony MOI, committed by GitHub

Fix convert_graph_to_onnx (#5230)

parent 5543efd5
@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
     return input_vars, output_names, dynamic_axes, tokens
 
 
-def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
+def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
     # If no tokenizer provided
     if tokenizer is None:
         tokenizer = model
 
+    # Check the wanted framework is available
+    if framework == "pt" and not is_torch_available():
+        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
+
+    if framework == "tf" and not is_tf_available():
+        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
+
     print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
 
     # Allocate tokenizer and model
-    return pipeline(args.pipeline, model=model, tokenizer=tokenizer, framework=framework)
+    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
 
 
 def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
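The core bug this hunk fixes: `load_graph_from_args` read the module-level `args.pipeline` inside its body, so it only worked after the script's argparse block had run; importing the module and calling the function directly raised a NameError. Threading `pipeline_name` through as an explicit parameter removes the hidden global. A minimal usage sketch (the import path and model name are illustrative, assuming a transformers version carrying this patch):

    from transformers.convert_graph_to_onnx import load_graph_from_args

    # pipeline_name is now the first parameter; no argparse state is
    # needed, so the loader also works when called as a library function.
    nlp = load_graph_from_args("feature-extraction", "pt", "bert-base-cased")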
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
 def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
     if not is_tf_available():
-        raise Exception(
-            "Cannot convert {} because TF is not installed. Please install torch first.".format(args.model)
-        )
+        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
 
     print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
@@ -187,11 +191,12 @@ def convert(
     opset: int,
     tokenizer: Optional[str] = None,
     use_external_format: bool = False,
+    pipeline_name: str = "feature-extraction",
 ):
     print("ONNX opset version set to: {}".format(opset))
 
     # Load the pipeline
-    nlp = load_graph_from_args(framework, model, tokenizer)
+    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
 
     parent = dirname(output)
     if not exists(parent):
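`convert` gains `pipeline_name` with a default of "feature-extraction", so existing callers are unaffected while new ones can pick which pipeline head to export. A hedged programmatic example (the model, output path, and opset value are illustrative choices, not values from this commit):

    from transformers.convert_graph_to_onnx import convert

    # Keyword arguments mirror the updated signature; pipeline_name
    # falls back to "feature-extraction" when omitted.
    convert(
        framework="pt",
        model="bert-base-cased",
        output="onnx/bert-base-cased.onnx",
        opset=11,
        pipeline_name="feature-extraction",
    )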
@@ -229,7 +234,15 @@ if __name__ == "__main__":
     try:
         # Convert
-        convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
+        convert(
+            args.framework,
+            args.model,
+            args.output,
+            args.opset,
+            args.tokenizer,
+            args.use_external_format,
+            args.pipeline,
+        )
 
         # And verify
         if args.check_loading:
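In the `__main__` block, `args.pipeline` is passed positionally, so it must line up with the seventh parameter of `convert`. An equivalent keyword form (a sketch, not part of the commit) makes that mapping explicit and stays correct if the signature grows again:

    convert(
        args.framework,
        args.model,
        args.output,
        args.opset,
        tokenizer=args.tokenizer,
        use_external_format=args.use_external_format,
        pipeline_name=args.pipeline,
    )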