Unverified Commit 507601a5 authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Prepare deprecated ONNX exporter for torch v1.11 (#15388)

* Prepare deprecated ONNX exporter for PyTorch v1.11

* Add deprecation warning
parent 4996922b
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from argparse import ArgumentParser from argparse import ArgumentParser
from os import listdir, makedirs from os import listdir, makedirs
from pathlib import Path from pathlib import Path
...@@ -278,6 +279,9 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format ...@@ -278,6 +279,9 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt") input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names) ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
export( export(
nlp.model, nlp.model,
model_args, model_args,
...@@ -290,6 +294,17 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format ...@@ -290,6 +294,17 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
enable_onnx_checker=True, enable_onnx_checker=True,
opset_version=opset, opset_version=opset,
) )
else:
export(
nlp.model,
model_args,
f=output.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
...@@ -356,6 +371,10 @@ def convert( ...@@ -356,6 +371,10 @@ def convert(
Returns: Returns:
""" """
warnings.warn(
"The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of Transformers",
FutureWarning,
)
print(f"ONNX opset version set to: {opset}") print(f"ONNX opset version set to: {opset}")
# Load the pipeline # Load the pipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment