Unverified Commit c8d3fa0d authored by Julien Plu, committed by GitHub

Check TF ops for ONNX compliance (#10025)



* Add check-ops script

* Finish implementing check_tf_ops and start the test

* Make the test mandatory only for BERT

* Update tf_ops folder

* Remove useless classes

* Add the ONNX test for GPT2 and BART

* Add an onnxruntime slow test + better opset flexibility

* Fix test + apply style

* fix tests

* Switch min opset from 12 to 10

* Update src/transformers/file_utils.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix GPT2

* Remove extra shape_list usage

* Fix GPT2

* Address Morgan's comments
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 93bd2f70
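The diffs below flip a new test_onnx flag to False for architectures whose TensorFlow graphs cannot be exported yet. As a rough, hedged sketch only (not the actual tester mixin in the library), such a flag could gate an op-compatibility check by reusing the SavedModel scan from the script added further down in this commit. The class name, the onnx_min_opset attribute, and the placeholder build_model() helper are illustrative assumptions.

import json
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel


class ONNXComplianceSketch:
    # Mirrors the flag set in the diffs below; unsupported models turn it off.
    test_onnx = True
    # Assumed attribute, in the spirit of "Switch min opset from 12 to 10".
    onnx_min_opset = 10

    def build_model(self):
        # Placeholder; a real tester would build the architecture under test.
        return tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

    def test_saved_model_ops_are_onnx_compatible(self):
        if not self.test_onnx:
            return  # export is not expected to work for this architecture yet

        # Union of the per-opset op lists, exactly as the check script does.
        with open(os.path.join("utils", "tf_ops", "onnx.json")) as f:
            opsets = json.load(f)["opsets"]
        supported = set()
        for i in range(1, self.onnx_min_opset + 1):
            supported.update(opsets[str(i)])

        model = self.build_model()
        with tempfile.TemporaryDirectory() as tmp_dir:
            tf.saved_model.save(model, tmp_dir)
            saved_model = SavedModel()
            with open(os.path.join(tmp_dir, "saved_model.pb"), "rb") as f:
                saved_model.ParseFromString(f.read())

        # Collect every op used by the graph and by its functions.
        ops = set()
        for meta_graph in saved_model.meta_graphs:
            ops.update(node.op for node in meta_graph.graph_def.node)
            for func in meta_graph.graph_def.library.function:
                ops.update(node.op for node in func.node_def)

        # The real check script additionally whitelists internal TF ops
        # such as ReadVariableOp (see INTERNAL_OPS below).
        incompatible = sorted(op for op in ops if op not in supported)
        assert not incompatible, f"Not exportable at opset {self.onnx_min_opset}: {incompatible}"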
@@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMarianModelTester(self)
...
@@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMBartModelTester(self)
...
@@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    class TFMobileBertModelTester(object):
        def __init__(
...
@@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMPNetModelTester(self)
...
@@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
    )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFOpenAIGPTModelTester(self)
...
@@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFPegasusModelTester(self)
...
@@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
        else ()
    )
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFRobertaModelTester(self)
...
@@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFT5ModelTester(self)
@@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
    is_encoder_decoder = False
    all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFT5EncoderOnlyModelTester(self)
...
@@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
    # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
    test_resize_embeddings = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFTransfoXLModelTester(self)
...
@@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFXLMWithLMHeadModel,) if is_tf_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFXLMModelTester(self)
...
@@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
        (TFXLNetLMHeadModel,) if is_tf_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFXLNetModelTester(self)
...
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel


# All paths are set with the intent that you should run this script from the root of the repo with the command
# python utils/check_tf_ops.py
REPO_PATH = "."

# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS = [
    "Assert",
    "AssignVariableOp",
    "EmptyTensorList",
    "MergeV2Checkpoints",
    "ReadVariableOp",
    "ResourceGather",
    "RestoreV2",
    "SaveV2",
    "ShardedFilename",
    "StatefulPartitionedCall",
    "StaticRegexFullMatch",
    "VarHandleOp",
]


def onnx_compliancy(saved_model_path, strict, opset):
    saved_model = SavedModel()
    onnx_ops = []

    # Gather every op supported up to (and including) the requested opset: the per-opset lists in onnx.json are unioned.
    with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
        onnx_opsets = json.load(f)["opsets"]

    for i in range(1, opset + 1):
        onnx_ops.extend(onnx_opsets[str(i)])

    with open(saved_model_path, "rb") as f:
        saved_model.ParseFromString(f.read())

    model_op_names = set()

    # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
    for meta_graph in saved_model.meta_graphs:
        # Add operations in the graph definition
        model_op_names.update(node.op for node in meta_graph.graph_def.node)

        # Go through the functions in the graph definition
        for func in meta_graph.graph_def.library.function:
            # Add operations in each function
            model_op_names.update(node.op for node in func.node_def)

    # Convert to a sorted list for deterministic output
    model_op_names = sorted(model_op_names)
    incompatible_ops = []

    for op in model_op_names:
        if op not in onnx_ops and op not in INTERNAL_OPS:
            incompatible_ops.append(op)

    if strict and len(incompatible_ops) > 0:
        raise Exception(
            f"Found the following incompatible ops for the opset {opset}:\n" + "\n".join(incompatible_ops)
        )
    elif len(incompatible_ops) > 0:
        print(f"Found the following incompatible ops for the opset {opset}:")
        print(*incompatible_ops, sep="\n")
    else:
        print(f"The saved model {saved_model_path} can be properly converted to ONNX.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
    parser.add_argument(
        "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
    )
    parser.add_argument(
        "--framework", choices=["onnx"], default="onnx", help="Framework against which to test the saved model."
    )
    parser.add_argument(
        "--strict", action="store_true", help="Whether to make the check strict (raise an error) or only print a warning."
    )
    args = parser.parse_args()

    if args.framework == "onnx":
        onnx_compliancy(args.saved_model_path, args.strict, args.opset)
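For reference, a minimal usage sketch of the script above: it only needs the path to a SavedModel protobuf, so any exported TF graph can be checked. The toy Keras model, the output directory, and the utils/check_tf_ops.py path (taken from the commit message) are illustrative assumptions; transformers TF models are typically exported with save_pretrained(..., saved_model=True) instead of the toy export shown here.

import tensorflow as tf

# Build and export a tiny Keras model purely to have a saved_model.pb to scan.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, activation="relu", input_shape=(8,))])
tf.saved_model.save(model, "toy_saved_model")  # writes toy_saved_model/saved_model.pb

# Then, from the root of the repository:
#   python utils/check_tf_ops.py --saved_model_path toy_saved_model/saved_model.pb --opset 10 --strict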
{
"opsets": {
"1": [
"Abs",
"Add",
"AddV2",
"ArgMax",
"ArgMin",
"AvgPool",
"AvgPool3D",
"BatchMatMul",
"BatchMatMulV2",
"BatchToSpaceND",
"BiasAdd",
"BiasAddV1",
"Cast",
"Ceil",
"CheckNumerics",
"ComplexAbs",
"Concat",
"ConcatV2",
"Const",
"ConstV2",
"Conv1D",
"Conv2D",
"Conv2DBackpropInput",
"Conv3D",
"Conv3DBackpropInputV2",
"DepthToSpace",
"DepthwiseConv2d",
"DepthwiseConv2dNative",
"Div",
"Dropout",
"Elu",
"Equal",
"Erf",
"Exp",
"ExpandDims",
"Flatten",
"Floor",
"Gather",
"GatherNd",
"GatherV2",
"Greater",
"Identity",
"IdentityN",
"If",
"LRN",
"LSTMBlockCell",
"LeakyRelu",
"Less",
"Log",
"LogSoftmax",
"LogicalAnd",
"LogicalNot",
"LogicalOr",
"LookupTableSizeV2",
"MatMul",
"Max",
"MaxPool",
"MaxPool3D",
"MaxPoolV2",
"Maximum",
"Mean",
"Min",
"Minimum",
"MirrorPad",
"Mul",
"Neg",
"NoOp",
"NotEqual",
"OneHot",
"Pack",
"Pad",
"PadV2",
"Placeholder",
"PlaceholderV2",
"PlaceholderWithDefault",
"Pow",
"Prod",
"RFFT",
"RandomNormal",
"RandomNormalLike",
"RandomUniform",
"RandomUniformLike",
"RealDiv",
"Reciprocal",
"Relu",
"Relu6",
"Reshape",
"Rsqrt",
"Selu",
"Shape",
"Sigmoid",
"Sign",
"Size",
"Slice",
"Softmax",
"Softplus",
"Softsign",
"SpaceToBatchND",
"SpaceToDepth",
"Split",
"SplitV",
"Sqrt",
"Square",
"SquaredDifference",
"Squeeze",
"StatelessIf",
"StopGradient",
"StridedSlice",
"StringJoin",
"Sub",
"Sum",
"Tanh",
"Tile",
"TopKV2",
"Transpose",
"TruncateDiv",
"Unpack",
"ZerosLike"
],
"2": [],
"3": [],
"4": [],
"5": [],
"6": [
"AddN",
"All",
"Any",
"FloorDiv",
"FusedBatchNorm",
"FusedBatchNormV2",
"FusedBatchNormV3"
],
"7": [
"Acos",
"Asin",
"Atan",
"Cos",
"Fill",
"FloorMod",
"GreaterEqual",
"LessEqual",
"Loop",
"MatrixBandPart",
"Multinomial",
"Range",
"ResizeBilinear",
"ResizeNearestNeighbor",
"Scan",
"Select",
"SelectV2",
"Sin",
"SoftmaxCrossEntropyWithLogits",
"SparseSoftmaxCrossEntropyWithLogits",
"StatelessWhile",
"Tan",
"TensorListFromTensor",
"TensorListGetItem",
"TensorListLength",
"TensorListReserve",
"TensorListResize",
"TensorListSetItem",
"TensorListStack",
"While"
],
"8": [
"BroadcastTo",
"ClipByValue",
"FIFOQueueV2",
"HashTableV2",
"IteratorGetNext",
"IteratorV2",
"LookupTableFindV2",
"MaxPoolWithArgmax",
"QueueDequeueManyV2",
"QueueDequeueUpToV2",
"QueueDequeueV2",
"ReverseSequence"
],
"9": [
"SegmentMax",
"SegmentMean",
"SegmentMin",
"SegmentProd",
"SegmentSum",
"Sinh",
"SparseSegmentMean",
"SparseSegmentMeanWithNumSegments",
"SparseSegmentSqrtN",
"SparseSegmentSqrtNWithNumSegments",
"SparseSegmentSum",
"SparseSegmentSumWithNumSegments",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",
"UnsortedSegmentSum",
"Where"
],
"10": [
"CropAndResize",
"CudnnRNN",
"DynamicStitch",
"FakeQuantWithMinMaxArgs",
"IsFinite",
"IsInf",
"NonMaxSuppressionV2",
"NonMaxSuppressionV3",
"NonMaxSuppressionV4",
"NonMaxSuppressionV5",
"ParallelDynamicStitch",
"ReverseV2",
"Roll"
],
"11": [
"Bincount",
"Cumsum",
"InvertPermutation",
"LeftShift",
"MatrixDeterminant",
"MatrixDiagPart",
"MatrixDiagPartV2",
"MatrixDiagPartV3",
"RaggedRange",
"RightShift",
"Round",
"ScatterNd",
"SparseFillEmptyRows",
"SparseReshape",
"SparseToDense",
"TensorScatterUpdate",
"Unique"
],
"12": [
"Einsum",
"MatrixDiag",
"MatrixDiagV2",
"MatrixDiagV3",
"MatrixSetDiagV3",
"SquaredDistance"
],
"13": []
}
}
\ No newline at end of file
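The opset keys in this table are incremental: the check script unions the lists from "1" up to the target opset, so an op that only appears under "12" (such as Einsum) is reported as incompatible when checking against the new minimum opset of 10. A small illustration, assuming the file lives at utils/tf_ops/onnx.json as in the script:

import json

# Load the op table consumed by the check script.
with open("utils/tf_ops/onnx.json") as f:
    opsets = json.load(f)["opsets"]

def supported_ops(opset):
    """Union of the ops introduced up to and including `opset`."""
    ops = set()
    for i in range(1, opset + 1):
        ops.update(opsets[str(i)])
    return ops

# "Einsum" only appears under opset 12, so it is not available at opset 10.
print("Einsum" in supported_ops(10))  # False
print("Einsum" in supported_ops(12))  # True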