chenpangpang / transformers

Commit 401fcca6 (unverified), authored Jun 27, 2022 by Yih-Dar, committed by GitHub on Jun 27, 2022

Fix TF GPT2 test_onnx_runtime_optimize (#17874)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

Parent: cc5c061e
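The test added in the diff below follows the common test_onnx_runtime_optimize pattern from test_modeling_tf_common.py: build the model, run it once on its dummy inputs, convert it with tf2onnx, and hand the serialized proto to an ONNX Runtime session. The following is a standalone sketch of that round trip, not part of the commit; the tiny GPT2Config values and the opset of 13 are illustrative assumptions (the real test uses the model tester's config and self.onnx_min_opset), and it needs tf2onnx/onnxruntime versions compatible with the installed TensorFlow build.

import onnxruntime
import tf2onnx
from transformers import GPT2Config, TFGPT2Model

# Tiny, illustrative configuration so the sketch builds quickly.
config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
model = TFGPT2Model(config)
model(model.dummy_inputs)  # run once so the Keras model is fully built

# Convert the built Keras model to an ONNX proto, then load it into ONNX Runtime.
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=13)
session = onnxruntime.InferenceSession(
    onnx_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
print([inp.name for inp in session.get_inputs()])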
Showing 1 changed file with 27 additions and 1 deletion:

tests/models/gpt2/test_modeling_tf_gpt2.py  (+27, -1)
tests/models/gpt2/test_modeling_tf_gpt2.py

@@ -16,7 +16,7 @@
 import unittest
 
 from transformers import GPT2Config, is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers.testing_utils import require_tf, require_tf2onnx, slow
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
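The only change in this first hunk is the extra require_tf2onnx import, which lets the new test be skipped on machines without tf2onnx. A hedged sketch of what a require_* test decorator of this kind typically does (this is not the actual transformers.testing_utils implementation):

import importlib.util
import unittest


def require_tf2onnx_sketch(test_case):
    """Skip the decorated test when the tf2onnx package is not installed."""
    if importlib.util.find_spec("tf2onnx") is None:
        return unittest.skip("test requires tf2onnx")(test_case)
    return test_case

The second hunk then adds the overriding test itself: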
@@ -444,6 +444,32 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
             model = TFGPT2Model.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
+    # overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument
+    # `batch_dims` > 0"
+    @require_tf2onnx
+    @slow
+    def test_onnx_runtime_optimize(self):
+        if not self.test_onnx:
+            return
+
+        import onnxruntime
+        import tf2onnx
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            # Skip these 2 classes which uses `tf.gather` with `batch_dims=1`
+            if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]:
+                continue
+
+            model = model_class(config)
+            model(model.dummy_inputs)
+
+            onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
+
+            onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
+
 
 @require_tf
 class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
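The two classes skipped above, TFGPT2ForSequenceClassification and TFGPT2DoubleHeadsModel, are the ones that pick per-example positions out of the hidden states with tf.gather(..., batch_dims=1), the construct the in-test comment says ONNX Runtime optimization cannot handle. Below is a minimal, self-contained sketch of that construct outside transformers; the toy shapes, input names, and opset are assumptions, and whether the optimized session actually misbehaves depends on the tf2onnx and onnxruntime versions in use.

import onnxruntime
import tensorflow as tf
import tf2onnx

# Toy functional model: for each example in the batch, gather one position
# from a sequence of hidden states, the way a classification head picks the
# last non-padded token.
hidden = tf.keras.Input(shape=(8, 4), dtype=tf.float32, name="hidden")
positions = tf.keras.Input(shape=(1,), dtype=tf.int32, name="positions")
picked = tf.gather(hidden, positions, batch_dims=1, axis=1)
toy_model = tf.keras.Model(inputs=[hidden, positions], outputs=picked)

spec = (
    tf.TensorSpec((None, 8, 4), tf.float32, name="hidden"),
    tf.TensorSpec((None, 1), tf.int32, name="positions"),
)
onnx_proto, _ = tf2onnx.convert.from_keras(toy_model, input_signature=spec, opset=13)

# Request full graph optimization, the step the in-test comment refers to.
options = onnxruntime.SessionOptions()
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
session = onnxruntime.InferenceSession(
    onnx_proto.SerializeToString(),
    sess_options=options,
    providers=["CPUExecutionProvider"],
)
print([i.name for i in session.get_inputs()])

With tf2onnx and onnxruntime installed and the transformers slow-test switch enabled (RUN_SLOW=1), the overridden test runs the same convert-and-load round trip for the remaining GPT-2 model classes while leaving the two gather-based heads out.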