Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
55b5100e
Commit
55b5100e
authored
Aug 17, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 327149314
parent
92ea3959
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
397 additions
and
2 deletions
+397
-2
official/nlp/modeling/layers/README.md
official/nlp/modeling/layers/README.md
+8
-2
official/nlp/modeling/layers/__init__.py
official/nlp/modeling/layers/__init__.py
+1
-0
official/nlp/modeling/layers/mat_mul_with_margin.py
official/nlp/modeling/layers/mat_mul_with_margin.py
+73
-0
official/nlp/modeling/layers/mat_mul_with_margin_test.py
official/nlp/modeling/layers/mat_mul_with_margin_test.py
+57
-0
official/nlp/modeling/models/README.md
official/nlp/modeling/models/README.md
+3
-0
official/nlp/modeling/models/__init__.py
official/nlp/modeling/models/__init__.py
+1
-0
official/nlp/modeling/models/dual_encoder.py
official/nlp/modeling/models/dual_encoder.py
+129
-0
official/nlp/modeling/models/dual_encoder_test.py
official/nlp/modeling/models/dual_encoder_test.py
+125
-0
No files found.
official/nlp/modeling/layers/README.md
View file @
55b5100e
...
@@ -11,6 +11,12 @@ assemble new layers, networks, or models.
...
@@ -11,6 +11,12 @@ assemble new layers, networks, or models.
*
[
CachedAttention
](
attention.py
)
implements an attention layer with cache
*
[
CachedAttention
](
attention.py
)
implements an attention layer with cache
used for auto-regressive decoding.
used for auto-regressive decoding.
*
[
MatMulWithMargin
](
mat_mul_with_margin.py
)
implements a matrix
multiplication with margin layer used for training retrieval / ranking
tasks, as described in
[
"Improving Multilingual Sentence Embedding using
Bi-directional Dual Encoder with Additive Margin
Softmax"
](
https://www.ijcai.org/Proceedings/2019/0746.pdf
)
.
*
[
MultiChannelAttention
](
multi_channel_attention.py
)
implements a variant of
*
[
MultiChannelAttention
](
multi_channel_attention.py
)
implements a variant of
multi-head attention which can be used to merge multiple streams for
multi-head attention which can be used to merge multiple streams for
cross-attentions.
cross-attentions.
...
@@ -24,8 +30,8 @@ assemble new layers, networks, or models.
...
@@ -24,8 +30,8 @@ assemble new layers, networks, or models.
[
"Attention Is All You Need"
](
https://arxiv.org/abs/1706.03762
)
.
[
"Attention Is All You Need"
](
https://arxiv.org/abs/1706.03762
)
.
*
[
TransformerDecoderLayer
](
transformer.py
)
TransformerDecoderLayer is made up
*
[
TransformerDecoderLayer
](
transformer.py
)
TransformerDecoderLayer is made up
of self multi-head attention, cross multi-head attention and
of self multi-head attention, cross multi-head attention and
feedforward
feedforward
network.
network.
*
[
ReZeroTransformer
](
rezero_transformer.py
)
implements Transformer with
*
[
ReZeroTransformer
](
rezero_transformer.py
)
implements Transformer with
ReZero described in
ReZero described in
...
...
official/nlp/modeling/layers/__init__.py
View file @
55b5100e
...
@@ -20,6 +20,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
...
@@ -20,6 +20,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from
official.nlp.modeling.layers.gated_feedforward
import
GatedFeedforward
from
official.nlp.modeling.layers.gated_feedforward
import
GatedFeedforward
from
official.nlp.modeling.layers.masked_lm
import
MaskedLM
from
official.nlp.modeling.layers.masked_lm
import
MaskedLM
from
official.nlp.modeling.layers.masked_softmax
import
MaskedSoftmax
from
official.nlp.modeling.layers.masked_softmax
import
MaskedSoftmax
from
official.nlp.modeling.layers.mat_mul_with_margin
import
MatMulWithMargin
from
official.nlp.modeling.layers.multi_channel_attention
import
*
from
official.nlp.modeling.layers.multi_channel_attention
import
*
from
official.nlp.modeling.layers.on_device_embedding
import
OnDeviceEmbedding
from
official.nlp.modeling.layers.on_device_embedding
import
OnDeviceEmbedding
from
official.nlp.modeling.layers.position_embedding
import
PositionEmbedding
from
official.nlp.modeling.layers.position_embedding
import
PositionEmbedding
...
...
official/nlp/modeling/layers/mat_mul_with_margin.py
0 → 100644
View file @
55b5100e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dot product with margin layer."""
# pylint: disable=g-classes-have-attributes
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
from
typing
import
Tuple
# Import libraries
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MatMulWithMargin(tf.keras.layers.Layer):
  """This layer computes a dot product matrix given two encoded inputs.

  The diagonal of the left->right dot-product matrix (the positive pairs)
  has `logit_margin` subtracted from it, and the whole matrix is scaled by
  `logit_scale`, producing additive-margin softmax logits for training
  retrieval / ranking tasks.

  Arguments:
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin value between the positive and negative examples
      when doing training.
  """

  def __init__(self, logit_scale=1.0, logit_margin=0.0, **kwargs):
    super(MatMulWithMargin, self).__init__(**kwargs)
    self.logit_scale = logit_scale
    self.logit_margin = logit_margin

  def call(self, left_encoded: tf.Tensor,
           right_encoded: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes margin-adjusted dot-product logits in both directions.

    Args:
      left_encoded: A [batch_size, dim] tensor of left-side encodings. Row i
        is assumed to be the positive pair of row i in `right_encoded`.
      right_encoded: A [batch_size, dim] tensor of right-side encodings.

    Returns:
      A (left_logits, right_logits) tuple of [batch_size, batch_size]
      tensors; `right_logits` is the transpose of `left_logits`.
    """
    batch_size = tf_utils.get_shape_list(
        left_encoded, name='sequence_output_tensor')[0]

    # Left -> Right dot product.
    left_dot_products = tf.matmul(
        left_encoded, right_encoded, transpose_b=True)

    # Subtract the margin from the diagonal only (the positive pairs),
    # then scale; also stashed on self for downstream inspection.
    self.left_logits = self.logit_scale * (
        left_dot_products - self.logit_margin * tf.eye(batch_size))

    # Right -> Left dot product is the transpose of the left -> right one.
    self.right_logits = tf.transpose(self.left_logits)

    return (self.left_logits, self.right_logits)

  def get_config(self):
    config = {
        'logit_scale': self.logit_scale,
        'logit_margin': self.logit_margin}
    config.update(super(MatMulWithMargin, self).get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
official/nlp/modeling/layers/mat_mul_with_margin_test.py
0 → 100644
View file @
55b5100e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mat_mul_with_margin layer."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
# Import libraries
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling.layers
import
mat_mul_with_margin
class MatMulWithMarginTest(keras_parameterized.TestCase):

  def test_layer_invocation(self):
    """Validate that the Keras object can be created and invoked."""
    width = 512
    layer_under_test = mat_mul_with_margin.MatMulWithMargin()

    # Build two symbolic 2-D inputs (the batch dimension is implicit).
    left = tf.keras.Input(shape=(width,), dtype=tf.float32)
    right = tf.keras.Input(shape=(width,), dtype=tf.float32)
    logits_lr, logits_rl = layer_under_test(left, right)

    # Both logit matrices are square over the (unknown) batch size.
    for logits in (logits_lr, logits_rl):
      self.assertEqual([None, None], logits.shape.as_list())

  def test_serialize_deserialize(self):
    # Round-trip a layer through get_config / from_config.
    original = mat_mul_with_margin.MatMulWithMargin()
    restored = mat_mul_with_margin.MatMulWithMargin.from_config(
        original.get_config())

    # A successful round trip yields an identical config.
    self.assertAllEqual(original.get_config(), restored.get_config())
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/nlp/modeling/models/README.md
View file @
55b5100e
...
@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks.
...
@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks.
*
[
`BertPretrainer`
](
bert_pretrainer.py
)
implements a masked LM and a
*
[
`BertPretrainer`
](
bert_pretrainer.py
)
implements a masked LM and a
classification head using the Masked LM and Classification networks,
classification head using the Masked LM and Classification networks,
respectively.
respectively.
*
[
`DualEncoder`
](
dual_encoder.py
)
implements a dual encoder model, suitable for
retrieval tasks.
official/nlp/modeling/models/__init__.py
View file @
55b5100e
...
@@ -17,4 +17,5 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
...
@@ -17,4 +17,5 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from
official.nlp.modeling.models.bert_pretrainer
import
*
from
official.nlp.modeling.models.bert_pretrainer
import
*
from
official.nlp.modeling.models.bert_span_labeler
import
BertSpanLabeler
from
official.nlp.modeling.models.bert_span_labeler
import
BertSpanLabeler
from
official.nlp.modeling.models.bert_token_classifier
import
BertTokenClassifier
from
official.nlp.modeling.models.bert_token_classifier
import
BertTokenClassifier
from
official.nlp.modeling.models.dual_encoder
import
DualEncoder
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
official/nlp/modeling/models/dual_encoder.py
0 → 100644
View file @
55b5100e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
# Import libraries
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on
  the transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852)

  The DualEncoder allows a user to pass in a transformer stack, and build a
  dual encoder model based on the transformer stack.

  Arguments:
    network: A transformer network which should output an encoding output.
    max_seq_length: The maximum allowed sequence length for transformer.
    normalize: If set to True, normalize the encoding produced by transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. If set to 'predictions', it will output the embedding
      produced by the transformer network.
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:

    # Disable object tracking before super().__init__ so the sub-network and
    # config dict are not auto-tracked as layers/weights of this model.
    self._self_setattr_tracking = False
    self._config = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    self.network = network

    # Left tower inputs: (word_ids, mask, type_ids), each [batch, seq_len].
    left_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
    left_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
    left_type_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')

    left_inputs = [left_word_ids, left_mask, left_type_ids]
    # The encoder network returns (sequence_output, pooled_output); only the
    # pooled encoding is used here.
    _, left_encoded = network(left_inputs)
    if normalize:
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded)

    if output == 'logits':
      # Right tower inputs, sharing the same encoder network (true dual
      # encoder: both towers use the same weights).
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')

      right_inputs = [right_word_ids, right_mask, right_type_ids]
      _, right_encoded = network(right_inputs)
      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded)

      dot_products = layers.MatMulWithMargin(
          logit_scale=logit_scale,
          logit_margin=logit_margin,
          name='dot_product')

      inputs = [
          left_word_ids, left_mask, left_type_ids, right_word_ids,
          right_mask, right_type_ids]
      left_logits, right_logits = dot_products(left_encoded, right_encoded)

      outputs = [left_logits, right_logits]

    elif output == 'predictions':
      # Inference mode: encode the left inputs only and expose the embedding.
      inputs = [left_word_ids, left_mask, left_type_ids]
      outputs = left_encoded
    else:
      raise ValueError('output type %s is not supported' % output)

    super(DualEncoder, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items
official/nlp/modeling/models/dual_encoder_test.py
0 → 100644
View file @
55b5100e
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dual encoder network."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
# Import libraries
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling
import
networks
from
official.nlp.modeling.models
import
dual_encoder
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DualEncoderTest(keras_parameterized.TestCase):
  """Tests for the DualEncoder model (construction, invocation, serialization)."""

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder(self, hidden_size, output):
    """Validate that the Keras object can be created."""
    # Build a transformer network to use within the dual encoder model.
    vocab_size = 100
    sequence_length = 512
    test_network = networks.TransformerEncoder(
        vocab_size=vocab_size,
        num_layers=2,
        hidden_size=hidden_size,
        sequence_length=sequence_length)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    left_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

    if output == 'logits':
      # Logits mode consumes both towers' inputs and returns a logits pair.
      outputs = dual_encoder_model([
          left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
          right_type_ids])
      left_encoded, _ = outputs
    elif output == 'predictions':
      # Predictions mode consumes only the left tower's inputs.
      left_encoded = dual_encoder_model(
          [left_word_ids, left_mask, left_type_ids])

    # Validate that the outputs are of the expected shape.
    # NOTE(review): 768 matches hidden_size only in the 'predictions'
    # parameterization; in the 'logits' case `left_encoded` is a logits
    # matrix — confirm the intended expectation for that branch.
    expected_encoding_shape = [None, 768]
    self.assertAllEqual(expected_encoding_shape, left_encoded.shape.as_list())

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder_tensor_call(self, hidden_size, output):
    """Validate that the Keras object can be invoked."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 2
    test_network = networks.TransformerEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional data tensors to feed into the model.
    word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)

    # Invoke the model on the tensors. In Eager mode, this does the
    # actual calculation. (We can't validate the outputs, since the network is
    # too complex: this simply ensures we're not hitting runtime errors.)
    if output == 'logits':
      _ = dual_encoder_model(
          [word_ids, mask, type_ids, word_ids, mask, type_ids])
    elif output == 'predictions':
      _ = dual_encoder_model([word_ids, mask, type_ids])

  def test_serialize_deserialize(self):
    """Validate that the dual encoder model can be serialized / deserialized."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 32
    test_network = networks.TransformerEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network. (Note that all the
    # args are different, so we can catch any serialization mismatches.)
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output='predictions')

    # Create another dual encoder model via serialization and deserialization.
    config = dual_encoder_model.get_config()
    new_dual_encoder = dual_encoder.DualEncoder.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_dual_encoder.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(dual_encoder_model.get_config(),
                        new_dual_encoder.get_config())
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment