Unverified Commit 401fcca6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF GPT2 test_onnx_runtime_optimize (#17874)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent cc5c061e
......@@ -16,7 +16,7 @@
import unittest
from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf, require_tf2onnx, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -444,6 +444,32 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
model = TFGPT2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
# overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument
# `batch_dims` > 0"
@require_tf2onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return
import onnxruntime
import tf2onnx
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# Skip these 2 classes which uses `tf.gather` with `batch_dims=1`
if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]:
continue
model = model_class(config)
model(model.dummy_inputs)
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
@require_tf
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment