Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7e370ba7
Commit
7e370ba7
authored
Apr 23, 2021
by
Chen Chen
Committed by
A. Unique TensorFlower
Apr 23, 2021
Browse files
Internal change
PiperOrigin-RevId: 370159521
parent
76145d74
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
10 deletions
+33
-10
official/nlp/modeling/networks/encoder_scaffold.py
official/nlp/modeling/networks/encoder_scaffold.py
+10
-1
official/nlp/modeling/networks/encoder_scaffold_test.py
official/nlp/modeling/networks/encoder_scaffold_test.py
+23
-9
No files found.
official/nlp/modeling/networks/encoder_scaffold.py
View file @
7e370ba7
...
...
@@ -257,11 +257,20 @@ class EncoderScaffold(tf.keras.Model):
'pooler_layer_initializer'
:
self
.
_pooler_layer_initializer
,
'embedding_cls'
:
self
.
_embedding_network
,
'embedding_cfg'
:
self
.
_embedding_cfg
,
'hidden_cfg'
:
self
.
_hidden_cfg
,
'layer_norm_before_pooling'
:
self
.
_layer_norm_before_pooling
,
'return_all_layer_outputs'
:
self
.
_return_all_layer_outputs
,
'dict_outputs'
:
self
.
_dict_outputs
,
}
if
self
.
_hidden_cfg
:
config_dict
[
'hidden_cfg'
]
=
{}
for
k
,
v
in
self
.
_hidden_cfg
.
items
():
# `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
# `TransformerScaffold`, its `attention_cls` argument can be a `class`.
if
inspect
.
isclass
(
v
):
config_dict
[
'hidden_cfg'
][
k
]
=
tf
.
keras
.
utils
.
get_registered_name
(
v
)
else
:
config_dict
[
'hidden_cfg'
][
k
]
=
v
if
inspect
.
isclass
(
self
.
_hidden_cls
):
config_dict
[
'hidden_cls_string'
]
=
tf
.
keras
.
utils
.
get_registered_name
(
self
.
_hidden_cls
)
...
...
official/nlp/modeling/networks/encoder_scaffold_test.py
View file @
7e370ba7
...
...
@@ -31,9 +31,10 @@ from official.nlp.modeling.networks import encoder_scaffold
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"TestOnly"
)
class
ValidatedTransformerLayer
(
layers
.
Transformer
):
def
__init__
(
self
,
call_list
,
**
kwargs
):
def
__init__
(
self
,
call_list
,
call_class
=
None
,
**
kwargs
):
super
(
ValidatedTransformerLayer
,
self
).
__init__
(
**
kwargs
)
self
.
list
=
call_list
self
.
call_class
=
call_class
def
call
(
self
,
inputs
):
self
.
list
.
append
(
True
)
...
...
@@ -41,10 +42,16 @@ class ValidatedTransformerLayer(layers.Transformer):
def
get_config
(
self
):
config
=
super
(
ValidatedTransformerLayer
,
self
).
get_config
()
config
[
"call_list"
]
=
[]
config
[
"call_list"
]
=
self
.
list
config
[
"call_class"
]
=
tf
.
keras
.
utils
.
get_registered_name
(
self
.
call_class
)
return
config
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"TestLayerOnly"
)
class
TestLayer
(
tf
.
keras
.
layers
.
Layer
):
pass
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@
keras_parameterized
.
run_all_keras_modes
...
...
@@ -560,7 +567,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
self
.
assertNotEmpty
(
call_list
)
self
.
assertTrue
(
call_list
[
0
],
"The passed layer class wasn't instantiated."
)
def
test_serialize_deserialize
(
self
):
@
parameterized
.
parameters
(
True
,
False
)
def
test_serialize_deserialize
(
self
,
use_hidden_cls_instance
):
hidden_size
=
32
sequence_length
=
21
vocab_size
=
57
...
...
@@ -591,21 +599,27 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
"kernel_initializer"
:
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
"call_list"
:
call_list
call_list
,
"call_class"
:
TestLayer
}
# Create a small EncoderScaffold for testing. This time, we pass an already-
# instantiated layer object.
xformer
=
ValidatedTransformerLayer
(
**
hidden_cfg
)
test_network
=
encoder_scaffold
.
EncoderScaffold
(
kwargs
=
dict
(
num_hidden_instances
=
3
,
pooled_output_dim
=
hidden_size
,
pooler_layer_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
hidden_cls
=
xformer
,
embedding_cfg
=
embedding_cfg
)
if
use_hidden_cls_instance
:
xformer
=
ValidatedTransformerLayer
(
**
hidden_cfg
)
test_network
=
encoder_scaffold
.
EncoderScaffold
(
hidden_cls
=
xformer
,
**
kwargs
)
else
:
test_network
=
encoder_scaffold
.
EncoderScaffold
(
hidden_cls
=
ValidatedTransformerLayer
,
hidden_cfg
=
hidden_cfg
,
**
kwargs
)
# Create another network object from the first object's config.
new_network
=
encoder_scaffold
.
EncoderScaffold
.
from_config
(
test_network
.
get_config
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment