Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b6780b59
Commit
b6780b59
authored
Jun 30, 2022
by
Scott Zhu
Committed by
A. Unique TensorFlower
Jun 30, 2022
Browse files
Address a few legacy issues in the test.
PiperOrigin-RevId: 458268967
parent
08a9f1f8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
8 deletions
+2
-8
official/nlp/modeling/layers/tn_expand_condense_test.py
official/nlp/modeling/layers/tn_expand_condense_test.py
+2
-8
No files found.
official/nlp/modeling/layers/tn_expand_condense_test.py
View file @
b6780b59
...
...
@@ -19,8 +19,6 @@ import os
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.keras.testing_utils
import
layer_test
from
official.nlp.modeling.layers.tn_expand_condense
import
TNExpandCondense
...
...
@@ -45,13 +43,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@
parameterized
.
parameters
((
768
,
6
),
(
1024
,
2
))
def
test_keras_layer
(
self
,
input_dim
,
proj_multiple
):
self
.
skipTest
(
'Disable the test for now since it imports '
'keras.testing_utils, will reenable this test after we '
'fix the b/184578869'
)
# TODO(scottzhu): Reenable after fix b/184578869
data
=
np
.
random
.
normal
(
size
=
(
100
,
input_dim
))
data
=
data
.
astype
(
np
.
float32
)
layer_test
(
tf
.
keras
.
__internal__
.
utils
.
layer_test
(
TNExpandCondense
,
kwargs
=
{
'proj_multiplier'
:
proj_multiple
,
...
...
@@ -64,9 +58,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@
parameterized
.
parameters
((
768
,
6
),
(
1024
,
2
))
def
test_train
(
self
,
input_dim
,
proj_multiple
):
tf
.
keras
.
utils
.
set_random_seed
(
0
)
data
=
np
.
random
.
randint
(
10
,
size
=
(
100
,
input_dim
))
model
=
self
.
_build_model
(
data
,
proj_multiple
)
tf
.
keras
.
utils
.
set_random_seed
(
0
)
model
.
compile
(
optimizer
=
'adam'
,
loss
=
'binary_crossentropy'
,
metrics
=
[
'accuracy'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment