Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
02f8d387
Commit
02f8d387
authored
Aug 28, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 328989927
parent
bc4ccd2f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
9 deletions
+28
-9
official/nlp/modeling/models/dual_encoder.py
official/nlp/modeling/models/dual_encoder.py
+23
-8
official/nlp/modeling/models/dual_encoder_test.py
official/nlp/modeling/models/dual_encoder_test.py
+5
-1
No files found.
official/nlp/modeling/models/dual_encoder.py
View file @
02f8d387
...
...
@@ -67,15 +67,24 @@ class DualEncoder(tf.keras.Model):
self
.
network
=
network
left_word_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_word_ids'
)
left_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_mask'
)
left_type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_type_ids'
)
if
output
==
'logits'
:
left_word_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_word_ids'
)
left_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_mask'
)
left_type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'left_type_ids'
)
else
:
# Keep the consistant with legacy BERT hub module input names.
left_word_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_word_ids'
)
left_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_mask'
)
left_type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
max_seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
left_inputs
=
[
left_word_ids
,
left_mask
,
left_type_ids
]
_
,
left_encoded
=
network
(
left_inputs
)
left_sequence_output
,
left_encoded
=
network
(
left_inputs
)
if
normalize
:
left_encoded
=
tf
.
keras
.
layers
.
Lambda
(
...
...
@@ -108,13 +117,19 @@ class DualEncoder(tf.keras.Model):
elif
output
==
'predictions'
:
inputs
=
[
left_word_ids
,
left_mask
,
left_type_ids
]
outputs
=
left_encoded
# To keep consistent with legacy BERT hub modules, the outputs are
# "pooled_output" and "sequence_output".
outputs
=
[
left_encoded
,
left_sequence_output
]
else
:
raise
ValueError
(
'output type %s is not supported'
%
output
)
super
(
DualEncoder
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
outputs
,
**
kwargs
)
# Set _self_setattr_tracking to True so it can be exported with assets.
self
.
_self_setattr_tracking
=
True
def
get_config
(
self
):
return
self
.
_config
...
...
official/nlp/modeling/models/dual_encoder_test.py
View file @
02f8d387
...
...
@@ -64,13 +64,17 @@ class DualEncoderTest(keras_parameterized.TestCase):
left_encoded
,
_
=
outputs
elif
output
==
'predictions'
:
left_encoded
=
dual_encoder_model
([
left_encoded
,
left_sequence_output
=
dual_encoder_model
([
left_word_ids
,
left_mask
,
left_type_ids
])
# Validate that the outputs are of the expected shape.
expected_encoding_shape
=
[
None
,
768
]
self
.
assertAllEqual
(
expected_encoding_shape
,
left_encoded
.
shape
.
as_list
())
expected_sequence_shape
=
[
None
,
sequence_length
,
768
]
self
.
assertAllEqual
(
expected_sequence_shape
,
left_sequence_output
.
shape
.
as_list
())
@
parameterized
.
parameters
((
192
,
'logits'
),
(
768
,
'predictions'
))
def
test_dual_encoder_tensor_call
(
self
,
hidden_size
,
output
):
"""Validate that the Keras object can be invoked."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment