Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0cceabfc
Unverified
Commit
0cceabfc
authored
Aug 03, 2020
by
Yiming Shi
Committed by
GitHub
Aug 03, 2020
Browse files
Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor
parents
17821c0d
39ee0ac9
Changes
339
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
643 additions
and
393 deletions
+643
-393
official/nlp/modeling/layers/masked_lm_test.py
official/nlp/modeling/layers/masked_lm_test.py
+1
-8
official/nlp/modeling/layers/multi_channel_attention.py
official/nlp/modeling/layers/multi_channel_attention.py
+73
-39
official/nlp/modeling/layers/multi_channel_attention_test.py
official/nlp/modeling/layers/multi_channel_attention_test.py
+8
-3
official/nlp/modeling/layers/on_device_embedding.py
official/nlp/modeling/layers/on_device_embedding.py
+8
-0
official/nlp/modeling/layers/on_device_embedding_test.py
official/nlp/modeling/layers/on_device_embedding_test.py
+20
-0
official/nlp/modeling/layers/position_embedding.py
official/nlp/modeling/layers/position_embedding.py
+0
-1
official/nlp/modeling/layers/rezero_transformer.py
official/nlp/modeling/layers/rezero_transformer.py
+22
-31
official/nlp/modeling/layers/talking_heads_attention.py
official/nlp/modeling/layers/talking_heads_attention.py
+7
-7
official/nlp/modeling/layers/talking_heads_attention_test.py
official/nlp/modeling/layers/talking_heads_attention_test.py
+10
-8
official/nlp/modeling/layers/transformer.py
official/nlp/modeling/layers/transformer.py
+320
-41
official/nlp/modeling/layers/transformer_scaffold.py
official/nlp/modeling/layers/transformer_scaffold.py
+2
-3
official/nlp/modeling/layers/transformer_scaffold_test.py
official/nlp/modeling/layers/transformer_scaffold_test.py
+2
-2
official/nlp/modeling/layers/transformer_test.py
official/nlp/modeling/layers/transformer_test.py
+110
-1
official/nlp/modeling/losses/README.md
official/nlp/modeling/losses/README.md
+0
-3
official/nlp/modeling/losses/__init__.py
official/nlp/modeling/losses/__init__.py
+0
-1
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py
...deling/losses/weighted_sparse_categorical_crossentropy.py
+9
-39
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py
...g/losses/weighted_sparse_categorical_crossentropy_test.py
+12
-184
official/nlp/modeling/models/README.md
official/nlp/modeling/models/README.md
+2
-2
official/nlp/modeling/models/__init__.py
official/nlp/modeling/models/__init__.py
+1
-0
official/nlp/modeling/models/bert_classifier.py
official/nlp/modeling/models/bert_classifier.py
+36
-20
No files found.
Too many changes to show.
To preserve performance only
339 of 339+
files are displayed.
Plain diff
Email patch
official/nlp/modeling/layers/masked_lm_test.py
View file @
0cceabfc
...
@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
def
create_layer
(
self
,
def
create_layer
(
self
,
vocab_size
,
vocab_size
,
sequence_length
,
hidden_size
,
hidden_size
,
output
=
'predictions'
,
output
=
'predictions'
,
xformer_stack
=
None
):
xformer_stack
=
None
):
...
@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack
=
transformer_encoder
.
TransformerEncoder
(
xformer_stack
=
transformer_encoder
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
num_layers
=
1
,
num_layers
=
1
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
num_attention_heads
=
4
,
num_attention_heads
=
4
,
)
)
...
@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions
=
21
num_predictions
=
21
test_layer
=
self
.
create_layer
(
test_layer
=
self
.
create_layer
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
)
hidden_size
=
hidden_size
)
# Make sure that the output tensor of the masked LM is the right shape.
# Make sure that the output tensor of the masked LM is the right shape.
...
@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack
=
transformer_encoder
.
TransformerEncoder
(
xformer_stack
=
transformer_encoder
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
num_layers
=
1
,
num_layers
=
1
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
num_attention_heads
=
4
,
num_attention_heads
=
4
,
)
)
test_layer
=
self
.
create_layer
(
test_layer
=
self
.
create_layer
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
xformer_stack
=
xformer_stack
,
xformer_stack
=
xformer_stack
,
output
=
'predictions'
)
output
=
'predictions'
)
logit_layer
=
self
.
create_layer
(
logit_layer
=
self
.
create_layer
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
xformer_stack
=
xformer_stack
,
xformer_stack
=
xformer_stack
,
output
=
'logits'
)
output
=
'logits'
)
...
@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions
=
21
num_predictions
=
21
test_layer
=
self
.
create_layer
(
test_layer
=
self
.
create_layer
(
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
)
hidden_size
=
hidden_size
)
# Create a model from the masked LM layer.
# Create a model from the masked LM layer.
...
@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
...
@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
def
test_unknown_output_type_fails
(
self
):
def
test_unknown_output_type_fails
(
self
):
with
self
.
assertRaisesRegex
(
ValueError
,
'Unknown `output` value "bad".*'
):
with
self
.
assertRaisesRegex
(
ValueError
,
'Unknown `output` value "bad".*'
):
_
=
self
.
create_layer
(
_
=
self
.
create_layer
(
vocab_size
=
8
,
sequence_length
=
8
,
hidden_size
=
8
,
output
=
'bad'
)
vocab_size
=
8
,
hidden_size
=
8
,
output
=
'bad'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
official/nlp/
nhnet
/multi_channel_attention.py
→
official/nlp/
modeling/layers
/multi_channel_attention.py
View file @
0cceabfc
...
@@ -13,7 +13,8 @@
...
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Multi-channel decoder."""
"""Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
@@ -24,11 +25,24 @@ import math
...
@@ -24,11 +25,24 @@ import math
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.nlp.modeling
import
layers
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
masked_softmax
class
DocAttention
(
tf
.
keras
.
layers
.
Layer
):
class
VotingAttention
(
tf
.
keras
.
layers
.
Layer
):
"""Documents Attention layer."""
"""Voting Attention layer.
Arguments:
num_heads: the number of attention heads.
head_size: per-head hidden size.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_heads
,
num_heads
,
...
@@ -41,7 +55,7 @@ class DocAttention(tf.keras.layers.Layer):
...
@@ -41,7 +55,7 @@ class DocAttention(tf.keras.layers.Layer):
kernel_constraint
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
bias_constraint
=
None
,
**
kwargs
):
**
kwargs
):
super
(
Doc
Attention
,
self
).
__init__
(
**
kwargs
)
super
(
Voting
Attention
,
self
).
__init__
(
**
kwargs
)
self
.
_num_heads
=
num_heads
self
.
_num_heads
=
num_heads
self
.
_head_size
=
head_size
self
.
_head_size
=
head_size
self
.
_kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
kernel_initializer
)
self
.
_kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
kernel_initializer
)
...
@@ -52,29 +66,27 @@ class DocAttention(tf.keras.layers.Layer):
...
@@ -52,29 +66,27 @@ class DocAttention(tf.keras.layers.Layer):
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
def
build
(
self
,
unused_input_shapes
):
def
build
(
self
,
unused_input_shapes
):
self
.
_query_dense
=
layers
.
DenseEinsum
(
common_kwargs
=
dict
(
output_shape
=
(
self
.
_num_heads
,
self
.
_head_size
),
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
dtype
=
self
.
dtype
,
self
.
_query_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
name
=
"encdocatt_query"
)
"BAE,ENH->BANH"
,
self
.
_key_dense
=
layers
.
DenseEinsum
(
output_shape
=
(
None
,
self
.
_num_heads
,
self
.
_head_size
),
output_shape
=
(
self
.
_num_heads
,
self
.
_head_size
),
bias_axes
=
"NH"
,
kernel_initializer
=
self
.
_kernel_initializer
,
name
=
"query"
,
bias_initializer
=
self
.
_bias_initializer
,
**
common_kwargs
)
kernel_regularizer
=
self
.
_kernel_regularizer
,
self
.
_key_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
bias_regularizer
=
self
.
_bias_regularizer
,
"BAE,ENH->BANH"
,
activity_regularizer
=
self
.
_activity_regularizer
,
output_shape
=
(
None
,
self
.
_num_heads
,
self
.
_head_size
),
kernel_constraint
=
self
.
_kernel_constraint
,
bias_axes
=
"NH"
,
bias_constraint
=
self
.
_bias_constraint
,
name
=
"key"
,
dtype
=
self
.
dtype
,
**
common_kwargs
)
name
=
"encdocatt_key"
)
super
(
VotingAttention
,
self
).
build
(
unused_input_shapes
)
super
(
DocAttention
,
self
).
build
(
unused_input_shapes
)
def
call
(
self
,
encoder_outputs
,
doc_attention_mask
):
def
call
(
self
,
encoder_outputs
,
doc_attention_mask
):
num_docs
=
tf_utils
.
get_shape_list
(
encoder_outputs
,
expected_rank
=
[
4
])[
1
]
num_docs
=
tf_utils
.
get_shape_list
(
encoder_outputs
,
expected_rank
=
[
4
])[
1
]
...
@@ -95,33 +107,55 @@ class DocAttention(tf.keras.layers.Layer):
...
@@ -95,33 +107,55 @@ class DocAttention(tf.keras.layers.Layer):
return
tf
.
nn
.
softmax
(
doc_attention_probs
+
infadder
)
return
tf
.
nn
.
softmax
(
doc_attention_probs
+
infadder
)
class
MultiChannelAttention
(
layers
.
MultiHeadAttention
):
class
MultiChannelAttention
(
attention
.
MultiHeadAttention
):
"""Multi-channel Attention layer."""
"""Multi-channel Attention layer.
def
build
(
self
,
input_shape
):
Introduced in, [Generating Representative Headlines for News Stories
super
(
MultiChannelAttention
,
self
).
build
(
input_shape
)
](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
self
.
_masked_softmax
=
layers
.
MaskedSoftmax
(
mask_expansion_axes
=
[
2
])
target sequences.
def
call
(
self
,
inputs
,
attention_mask
=
None
):
Call args:
from_tensor
=
inputs
[
0
]
query: Query `Tensor` of shape `[B, T, dim]`.
to_tensor
=
inputs
[
1
]
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
doc_attention_probs
=
inputs
[
2
]
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel sources
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def
build_attention
(
self
,
rank
):
super
(
MultiChannelAttention
,
self
).
build_attention
(
rank
)
self
.
_masked_softmax
=
masked_softmax
.
MaskedSoftmax
(
mask_expansion_axes
=
[
2
])
def
call
(
self
,
query
,
value
,
key
=
None
,
context_attention_weights
=
None
,
attention_mask
=
None
):
if
not
self
.
_built_from_signature
:
self
.
_build_from_signature
(
query
,
value
,
key
=
key
)
if
key
is
None
:
key
=
value
# Scalar dimensions referenced here:
# Scalar dimensions referenced here:
# B = batch size (number of stories)
# B = batch size (number of stories)
# A = num_docs (number of docs)
# A = num_docs (number of docs)
# F =
`from_tensor`
sequence length
# F =
target
sequence length
# T =
`to_tensor`
sequence length
# T =
source
sequence length
# N = `num_attention_heads`
# N = `num_attention_heads`
# H = `size_per_head`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
# `query_tensor` = [B, F, N ,H]
query_tensor
=
self
.
_query_dense
(
from_tensor
)
query_tensor
=
self
.
_query_dense
(
query
)
# `key_tensor` = [B, A, T, N, H]
# `key_tensor` = [B, A, T, N, H]
key_tensor
=
self
.
_key_dense
(
to_tensor
)
key_tensor
=
self
.
_key_dense
(
key
)
# `value_tensor` = [B, A, T, N, H]
# `value_tensor` = [B, A, T, N, H]
value_tensor
=
self
.
_value_dense
(
to_tensor
)
value_tensor
=
self
.
_value_dense
(
value
)
# Take the dot product between "query" and "key" to get the raw
# Take the dot product between "query" and "key" to get the raw
# attention scores.
# attention scores.
...
@@ -140,7 +174,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
...
@@ -140,7 +174,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# `context_layer` = [B, F, N, H]
# `context_layer` = [B, F, N, H]
context_layer
=
tf
.
einsum
(
"BANFT,BATNH->BAFNH"
,
attention_probs
,
context_layer
=
tf
.
einsum
(
"BANFT,BATNH->BAFNH"
,
attention_probs
,
value_tensor
)
value_tensor
)
attention_output
=
tf
.
einsum
(
"BNFA,BAFNH->BFNH"
,
doc
_attention_
prob
s
,
attention_output
=
tf
.
einsum
(
"BNFA,BAFNH->BFNH"
,
context
_attention_
weight
s
,
context_layer
)
context_layer
)
attention_output
=
self
.
_output_dense
(
attention_output
)
attention_output
=
self
.
_output_dense
(
attention_output
)
return
attention_output
return
attention_output
official/nlp/
nhnet
/multi_channel_attention_test.py
→
official/nlp/
modeling/layers
/multi_channel_attention_test.py
View file @
0cceabfc
...
@@ -22,14 +22,15 @@ from __future__ import print_function
...
@@ -22,14 +22,15 @@ from __future__ import print_function
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.
nhnet
import
multi_channel_attention
from
official.nlp.
modeling.layers
import
multi_channel_attention
class
MultiChannelAttentionTest
(
tf
.
test
.
TestCase
):
class
MultiChannelAttentionTest
(
tf
.
test
.
TestCase
):
def
test_doc_attention
(
self
):
def
test_doc_attention
(
self
):
num_heads
=
2
num_heads
=
2
doc_attention
=
multi_channel_attention
.
DocAttention
(
num_heads
,
head_size
=
8
)
doc_attention
=
multi_channel_attention
.
VotingAttention
(
num_heads
,
head_size
=
8
)
num_docs
=
3
num_docs
=
3
inputs
=
np
.
zeros
((
2
,
num_docs
,
10
,
16
),
dtype
=
np
.
float32
)
inputs
=
np
.
zeros
((
2
,
num_docs
,
10
,
16
),
dtype
=
np
.
float32
)
doc_mask
=
np
.
zeros
((
2
,
num_docs
),
dtype
=
np
.
float32
)
doc_mask
=
np
.
zeros
((
2
,
num_docs
),
dtype
=
np
.
float32
)
...
@@ -47,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
...
@@ -47,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
3
,
num_docs
,
4
,
2
))
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
3
,
num_docs
,
4
,
2
))
doc_probs
=
np
.
random
.
randint
(
doc_probs
=
np
.
random
.
randint
(
2
,
size
=
(
3
,
num_heads
,
4
,
num_docs
)).
astype
(
float
)
2
,
size
=
(
3
,
num_heads
,
4
,
num_docs
)).
astype
(
float
)
outputs
=
attention_layer
([
from_data
,
to_data
,
doc_probs
],
mask_data
)
outputs
=
attention_layer
(
query
=
from_data
,
value
=
to_data
,
context_attention_weights
=
doc_probs
,
attention_mask
=
mask_data
)
self
.
assertEqual
(
outputs
.
shape
,
(
3
,
4
,
8
))
self
.
assertEqual
(
outputs
.
shape
,
(
3
,
4
,
8
))
...
...
official/nlp/modeling/layers/on_device_embedding.py
View file @
0cceabfc
...
@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
...
@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
...
@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width
,
embedding_width
,
initializer
=
"glorot_uniform"
,
initializer
=
"glorot_uniform"
,
use_one_hot
=
False
,
use_one_hot
=
False
,
use_scale
=
False
,
**
kwargs
):
**
kwargs
):
super
(
OnDeviceEmbedding
,
self
).
__init__
(
**
kwargs
)
super
(
OnDeviceEmbedding
,
self
).
__init__
(
**
kwargs
)
...
@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
...
@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self
.
_embedding_width
=
embedding_width
self
.
_embedding_width
=
embedding_width
self
.
_initializer
=
initializer
self
.
_initializer
=
initializer
self
.
_use_one_hot
=
use_one_hot
self
.
_use_one_hot
=
use_one_hot
self
.
_use_scale
=
use_scale
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
config
=
{
...
@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
...
@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width"
:
self
.
_embedding_width
,
"embedding_width"
:
self
.
_embedding_width
,
"initializer"
:
self
.
_initializer
,
"initializer"
:
self
.
_initializer
,
"use_one_hot"
:
self
.
_use_one_hot
,
"use_one_hot"
:
self
.
_use_one_hot
,
"use_scale"
:
self
.
_use_scale
,
}
}
base_config
=
super
(
OnDeviceEmbedding
,
self
).
get_config
()
base_config
=
super
(
OnDeviceEmbedding
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
...
@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
...
@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list.
# Work around b/142213824: prefer concat to shape over a Python list.
tf
.
concat
([
tf
.
shape
(
inputs
),
[
self
.
_embedding_width
]],
axis
=
0
))
tf
.
concat
([
tf
.
shape
(
inputs
),
[
self
.
_embedding_width
]],
axis
=
0
))
embeddings
.
set_shape
(
inputs
.
shape
.
as_list
()
+
[
self
.
_embedding_width
])
embeddings
.
set_shape
(
inputs
.
shape
.
as_list
()
+
[
self
.
_embedding_width
])
if
self
.
_use_scale
:
embeddings
*=
self
.
_embedding_width
**
0.5
return
embeddings
return
embeddings
official/nlp/modeling/layers/on_device_embedding_test.py
View file @
0cceabfc
...
@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
...
@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output
=
model
.
predict
(
input_data
)
output
=
model
.
predict
(
input_data
)
self
.
assertEqual
(
tf
.
float16
,
output
.
dtype
)
self
.
assertEqual
(
tf
.
float16
,
output
.
dtype
)
def
test_use_scale_layer_invocation
(
self
):
vocab_size
=
31
embedding_width
=
27
test_layer
=
on_device_embedding
.
OnDeviceEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
use_scale
=
True
)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length
=
23
input_tensor
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
),
dtype
=
tf
.
int32
)
output_tensor
=
test_layer
(
input_tensor
)
# Create a model from the test layer.
model
=
tf
.
keras
.
Model
(
input_tensor
,
output_tensor
)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
3
input_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
))
output
=
model
.
predict
(
input_data
)
self
.
assertEqual
(
tf
.
float32
,
output
.
dtype
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/nlp/modeling/layers/position_embedding.py
View file @
0cceabfc
...
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
...
@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size"
:
self
.
_hidden_size
,
"hidden_size"
:
self
.
_hidden_size
,
"min_timescale"
:
self
.
_min_timescale
,
"min_timescale"
:
self
.
_min_timescale
,
"max_timescale"
:
self
.
_max_timescale
,
"max_timescale"
:
self
.
_max_timescale
,
"length"
:
self
.
_length
,
}
}
base_config
=
super
(
RelativePositionEmbedding
,
self
).
get_config
()
base_config
=
super
(
RelativePositionEmbedding
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
...
...
official/nlp/modeling/layers/rezero_transformer.py
View file @
0cceabfc
...
@@ -23,7 +23,6 @@ import gin
...
@@ -23,7 +23,6 @@ import gin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
dense_einsum
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
...
@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
self
.
_attention_layer
=
attention
.
MultiHeadAttention
(
num_heads
=
self
.
_num_heads
,
key_size
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout_rate
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
name
=
"self_attention"
)
self
.
_attention_layer
=
attention
.
MultiHeadAttention
(
num_heads
=
self
.
_num_heads
,
key_size
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout_rate
,
name
=
"self_attention"
,
**
common_kwargs
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
if
self
.
_use_layer_norm
:
if
self
.
_use_layer_norm
:
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
...
@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis
=-
1
,
axis
=-
1
,
epsilon
=
1e-12
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
dense_einsum
.
DenseEinsum
(
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
output_shape
=
self
.
_intermediate_size
,
"abc,cd->abd"
,
activation
=
None
,
output_shape
=
(
None
,
self
.
_intermediate_size
),
kernel_initializer
=
self
.
_kernel_initializer
,
bias_axes
=
"d"
,
bias_initializer
=
self
.
_bias_initializer
,
name
=
"intermediate"
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
**
common_kwargs
)
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
name
=
"intermediate"
)
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
if
policy
.
name
==
"mixed_bfloat16"
:
# bfloat16 causes BERT with the LAMB optimizer to not converge
# bfloat16 causes BERT with the LAMB optimizer to not converge
...
@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy
=
tf
.
float32
policy
=
tf
.
float32
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_intermediate_activation
,
dtype
=
policy
)
self
.
_intermediate_activation
,
dtype
=
policy
)
self
.
_output_dense
=
dense_einsum
.
DenseEinsum
(
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
output_shape
=
hidden_size
,
"abc,cd->abd"
,
kernel_initializer
=
self
.
_kernel_initializer
,
output_shape
=
(
None
,
hidden_size
),
bias_initializer
=
self
.
_bias_initializer
,
bias_axes
=
"d"
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
name
=
"output"
,
bias_regularizer
=
self
.
_bias_regularizer
,
**
common_kwargs
)
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
name
=
"output"
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
if
self
.
_use_layer_norm
:
if
self
.
_use_layer_norm
:
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
...
@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
...
@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
else
:
target_tensor
=
input_tensor
target_tensor
=
input_tensor
attention_inputs
=
[
target_tensor
,
input_tensor
]
attention_output
=
self
.
_attention_layer
(
attention_inputs
,
attention_mask
)
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
input_tensor
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
target_tensor
+
self
.
_rezero_a
*
attention_output
attention_output
=
target_tensor
+
self
.
_rezero_a
*
attention_output
if
self
.
_use_layer_norm
:
if
self
.
_use_layer_norm
:
...
...
official/nlp/modeling/layers/talking_heads_attention.py
View file @
0cceabfc
...
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
...
@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
"""
"""
def
_
build_attention
(
self
,
qkv_rank
):
def
build_attention
(
self
,
qkv_rank
):
"""Builds multi-head dot-product attention computations.
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
This function overrides base class to create additional linear projection
...
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
...
@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args:
Args:
qkv_rank: the rank of query, key, value tensors after projection.
qkv_rank: the rank of query, key, value tensors after projection.
"""
"""
super
(
TalkingHeadsAttention
,
self
).
_
build_attention
(
qkv_rank
)
super
(
TalkingHeadsAttention
,
self
).
build_attention
(
qkv_rank
)
# Build an equation:
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...
@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
...
@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
trainable
=
True
)
trainable
=
True
)
def
_
compute_attention
(
self
,
def
compute_attention
(
self
,
query_tensor
,
query_tensor
,
key_tensor
,
key_tensor
,
value_tensor
,
value_tensor
,
attention_mask
=
None
):
attention_mask
=
None
):
"""Applies Dot-product attention with query, key, value tensors.
"""Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection
This function overrides base class to apply additional linear projection
...
...
official/nlp/modeling/layers/talking_heads_attention_test.py
View file @
0cceabfc
...
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
value
=
tf
.
keras
.
Input
(
shape
=
(
20
,
80
))
value
=
tf
.
keras
.
Input
(
shape
=
(
20
,
80
))
output
=
test_layer
(
[
query
,
value
]
)
output
=
test_layer
(
query
=
query
,
value
=
value
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
]
+
output_dims
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
]
+
output_dims
)
def
test_non_masked_self_attention
(
self
):
def
test_non_masked_self_attention
(
self
):
...
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads
=
12
,
key_size
=
64
)
num_heads
=
12
,
key_size
=
64
)
# Create a 3-dimensional input (the first dimension is implicit).
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
=
test_layer
(
[
query
,
query
]
)
output
=
test_layer
(
query
=
query
,
value
=
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
def
test_attention_scores
(
self
):
def
test_attention_scores
(
self
):
...
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads
=
12
,
key_size
=
64
,
return_attention_scores
=
True
)
num_heads
=
12
,
key_size
=
64
,
return_attention_scores
=
True
)
# Create a 3-dimensional input (the first dimension is implicit).
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
,
coef
=
test_layer
(
[
query
,
query
]
)
output
,
coef
=
test_layer
(
query
=
query
,
value
=
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
self
.
assertEqual
(
coef
.
shape
.
as_list
(),
[
None
,
12
,
40
,
40
])
self
.
assertEqual
(
coef
.
shape
.
as_list
(),
[
None
,
12
,
40
,
40
])
...
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query
=
tf
.
keras
.
Input
(
shape
=
(
4
,
8
))
query
=
tf
.
keras
.
Input
(
shape
=
(
4
,
8
))
value
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
value
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
mask_tensor
=
tf
.
keras
.
Input
(
shape
=
(
4
,
2
))
mask_tensor
=
tf
.
keras
.
Input
(
shape
=
(
4
,
2
))
output
=
test_layer
(
[
query
,
value
],
mask_tensor
)
output
=
test_layer
(
query
=
query
,
value
=
value
,
attention_mask
=
mask_tensor
)
# Create a model containing the test layer.
# Create a model containing the test layer.
model
=
tf
.
keras
.
Model
([
query
,
value
,
mask_tensor
],
output
)
model
=
tf
.
keras
.
Model
([
query
,
value
,
mask_tensor
],
output
)
...
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
# Tests the layer with three inputs: Q, K, V.
key
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
key
=
tf
.
keras
.
Input
(
shape
=
(
2
,
8
))
output
=
test_layer
([
query
,
value
,
key
],
mask_tensor
)
output
=
test_layer
(
query
=
query
,
value
=
value
,
key
=
key
,
attention_mask
=
mask_tensor
)
model
=
tf
.
keras
.
Model
([
query
,
value
,
key
,
mask_tensor
],
output
)
model
=
tf
.
keras
.
Model
([
query
,
value
,
key
,
mask_tensor
],
output
)
masked_output_data
=
model
.
predict
([
from_data
,
to_data
,
to_data
,
mask_data
])
masked_output_data
=
model
.
predict
([
from_data
,
to_data
,
to_data
,
mask_data
])
...
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
))
kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
))
# Create a 3-dimensional input (the first dimension is implicit).
# Create a 3-dimensional input (the first dimension is implicit).
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
query
=
tf
.
keras
.
Input
(
shape
=
(
40
,
80
))
output
=
test_layer
(
[
query
,
query
]
)
output
=
test_layer
(
query
=
query
,
value
=
query
)
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
self
.
assertEqual
(
output
.
shape
.
as_list
(),
[
None
,
40
,
80
])
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
...
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
...
@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# Invoke the data with a random set of mask data. This should mask at least
# one element.
# one element.
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
"bool"
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
mask_shape
).
astype
(
"bool"
)
output
=
test_layer
(
[
query
,
value
],
mask_data
)
output
=
test_layer
(
query
=
query
,
value
=
value
,
attention_mask
=
mask_data
)
# Invoke the same data, but with a null mask (where no elements are masked).
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data
=
np
.
ones
(
mask_shape
)
null_mask_data
=
np
.
ones
(
mask_shape
)
unmasked_output
=
test_layer
([
query
,
value
],
null_mask_data
)
unmasked_output
=
test_layer
(
query
=
query
,
value
=
value
,
attention_mask
=
null_mask_data
)
# Because one data is masked and one is not, the outputs should not be the
# Because one data is masked and one is not, the outputs should not be the
# same.
# same.
self
.
assertNotAllClose
(
output
,
unmasked_output
)
self
.
assertNotAllClose
(
output
,
unmasked_output
)
...
...
official/nlp/modeling/layers/transformer.py
View file @
0cceabfc
...
@@ -23,7 +23,7 @@ import gin
...
@@ -23,7 +23,7 @@ import gin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
attention
from
official.nlp.modeling.layers
import
dense_einsum
from
official.nlp.modeling.layers
import
multi_channel_attention
from
official.nlp.modeling.layers.util
import
tf_function_if_eager
from
official.nlp.modeling.layers.util
import
tf_function_if_eager
...
@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
**
kwargs
):
**
kwargs
):
super
(
Transformer
,
self
).
__init__
(
**
kwargs
)
super
(
Transformer
,
self
).
__init__
(
**
kwargs
)
...
@@ -78,8 +87,12 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -78,8 +87,12 @@ class Transformer(tf.keras.layers.Layer):
self
.
_bias_initializer
=
tf
.
keras
.
initializers
.
get
(
bias_initializer
)
self
.
_bias_initializer
=
tf
.
keras
.
initializers
.
get
(
bias_initializer
)
self
.
_kernel_regularizer
=
tf
.
keras
.
regularizers
.
get
(
kernel_regularizer
)
self
.
_kernel_regularizer
=
tf
.
keras
.
regularizers
.
get
(
kernel_regularizer
)
self
.
_bias_regularizer
=
tf
.
keras
.
regularizers
.
get
(
bias_regularizer
)
self
.
_bias_regularizer
=
tf
.
keras
.
regularizers
.
get
(
bias_regularizer
)
self
.
_activity_regularizer
=
tf
.
keras
.
regularizers
.
get
(
activity_regularizer
)
self
.
_kernel_constraint
=
tf
.
keras
.
constraints
.
get
(
kernel_constraint
)
self
.
_kernel_constraint
=
tf
.
keras
.
constraints
.
get
(
kernel_constraint
)
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
self
.
_use_bias
=
use_bias
self
.
_norm_first
=
norm_first
self
.
_norm_epsilon
=
norm_epsilon
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
input_tensor
=
input_shape
[
0
]
if
len
(
input_shape
)
==
2
else
input_shape
input_tensor
=
input_shape
[
0
]
if
len
(
input_shape
)
==
2
else
input_shape
...
@@ -104,23 +117,21 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -104,23 +117,21 @@ class Transformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
self
.
_attention_layer
=
attention
.
MultiHeadAttention
(
num_heads
=
self
.
_num_heads
,
key_size
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout_rate
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_initializer
=
self
.
_kernel_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
name
=
"self_attention"
)
self
.
_attention_layer
=
attention
.
MultiHeadAttention
(
# pylint: disable=protected-access
num_heads
=
self
.
_num_heads
,
self
.
_attention_layer
.
build
([
input_tensor_shape
]
*
3
)
key_size
=
self
.
_attention_head_size
,
self
.
_attention_output_dense
=
self
.
_attention_layer
.
_output_dense
dropout
=
self
.
_attention_dropout_rate
,
# pylint: enable=protected-access
use_bias
=
self
.
_use_bias
,
name
=
"self_attention"
,
**
common_kwargs
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# It is probably safe in mixed_float16, but we haven't validated this yet.
...
@@ -128,19 +139,14 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -128,19 +139,14 @@ class Transformer(tf.keras.layers.Layer):
tf
.
keras
.
layers
.
LayerNormalization
(
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"self_attention_layer_norm"
,
name
=
"self_attention_layer_norm"
,
axis
=-
1
,
axis
=-
1
,
epsilon
=
1e-12
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
dense_einsum
.
DenseEinsum
(
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
output_shape
=
self
.
_intermediate_size
,
"abc,cd->abd"
,
activation
=
None
,
output_shape
=
(
None
,
self
.
_intermediate_size
),
kernel_initializer
=
self
.
_kernel_initializer
,
bias_axes
=
"d"
,
bias_initializer
=
self
.
_bias_initializer
,
name
=
"intermediate"
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
**
common_kwargs
)
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
name
=
"intermediate"
)
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
if
policy
.
name
==
"mixed_bfloat16"
:
# bfloat16 causes BERT with the LAMB optimizer to not converge
# bfloat16 causes BERT with the LAMB optimizer to not converge
...
@@ -149,20 +155,19 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -149,20 +155,19 @@ class Transformer(tf.keras.layers.Layer):
policy
=
tf
.
float32
policy
=
tf
.
float32
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_intermediate_activation
,
dtype
=
policy
)
self
.
_intermediate_activation
,
dtype
=
policy
)
self
.
_output_dense
=
dense_einsum
.
DenseEinsum
(
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
output_shape
=
hidden_size
,
"abc,cd->abd"
,
kernel_initializer
=
self
.
_kernel_initializer
,
output_shape
=
(
None
,
hidden_size
),
bias_initializer
=
self
.
_bias_initializer
,
bias_axes
=
"d"
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
name
=
"output"
,
bias_regularizer
=
self
.
_bias_regularizer
,
**
common_kwargs
)
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
,
name
=
"output"
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
# Use float32 in layernorm for numeric stability.
# Use float32 in layernorm for numeric stability.
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"output_layer_norm"
,
axis
=-
1
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
)
name
=
"output_layer_norm"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
)
super
(
Transformer
,
self
).
build
(
input_shape
)
super
(
Transformer
,
self
).
build
(
input_shape
)
...
@@ -193,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -193,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
"kernel_constraint"
:
"kernel_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
"bias_constraint"
:
"bias_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
)
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
"use_bias"
:
self
.
_use_bias
,
"norm_first"
:
self
.
_norm_first
,
"norm_epsilon"
:
self
.
_norm_epsilon
}
}
base_config
=
super
(
Transformer
,
self
).
get_config
()
base_config
=
super
(
Transformer
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
...
@@ -208,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -208,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
else
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
target_tensor
=
input_tensor
target_tensor
=
input_tensor
attention_inputs
=
[
target_tensor
,
input_tensor
]
attention_output
=
self
.
_attention_layer
(
attention_inputs
,
attention_mask
)
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
input_tensor
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
if
self
.
_norm_first
:
attention_output
)
attention_output
=
source_tensor
+
attention_output
else
:
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
attention_output
)
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_dense
(
attention_output
)
intermediate_output
=
self
.
_intermediate_activation_layer
(
intermediate_output
=
self
.
_intermediate_activation_layer
(
intermediate_output
)
intermediate_output
)
...
@@ -224,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
...
@@ -224,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
# add.
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
layer_output
=
self
.
_output_layer_norm
(
layer_output
+
attention_output
)
if
self
.
_norm_first
:
layer_output
=
source_attention_output
+
layer_output
else
:
layer_output
=
self
.
_output_layer_norm
(
layer_output
+
attention_output
)
return
layer_output
return
layer_output
...
@@ -236,3 +259,259 @@ class CompiledTransformer(Transformer):
...
@@ -236,3 +259,259 @@ class CompiledTransformer(Transformer):
@
tf_function_if_eager
(
experimental_compile
=
True
)
@
tf_function_if_eager
(
experimental_compile
=
True
)
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
return
super
(
CompiledTransformer
,
self
).
call
(
inputs
)
return
super
(
CompiledTransformer
,
self
).
call
(
inputs
)
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
  """Single transformer layer for decoder.

  It has three sub-layers:
  (1) a multi-head self-attention mechanism.
  (2) an encoder-decoder attention.
  (3) a positionwise fully connected feed-forward network.

  Arguments:
    num_attention_heads: Number of attention heads.
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout_rate: Dropout probability for the post-attention and output dropout.
    attention_dropout_rate: Dropout probability for within the attention layer.
    multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
      cross-attention between target sequences and source sequences.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
    use_bias: Whether to enable use_bias in attention layer. If set False,
      use_bias in attention layer is disabled.
    norm_first: Whether to normalize inputs to attention and intermediate dense
      layers. If set False, output of attention and intermediate dense layers is
      normalized.
    norm_epsilon: Epsilon value to initialize normalization layers.
  """

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               multi_channel_cross_attention=False,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               **kwargs):
    """Stores the configuration; sub-layers are created later in `build`."""
    super(TransformerDecoderLayer, self).__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    # Resolved to a callable here; `get_config` serializes it back.
    self.intermediate_activation = tf.keras.activations.get(
        intermediate_activation)
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self.multi_channel_cross_attention = multi_channel_cross_attention
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    # Pick the cross-attention implementation once, at construction time.
    if self.multi_channel_cross_attention:
      self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
    else:
      self._cross_attention_cls = attention.MultiHeadAttention

  def build(self, input_shape):
    """Creates the attention, feed-forward, dropout and layer-norm sub-layers.

    Args:
      input_shape: Sequence of input shapes. Only the first entry (the shape of
        the target tensor, `[batch, sequence, width]`) is inspected here.

    Raises:
      ValueError: If the first input is not three-dimensional, or if its width
        is not a multiple of `num_attention_heads`.
    """
    target_tensor_shape = tf.TensorShape(input_shape[0])
    if len(target_tensor_shape) != 3:
      raise ValueError(
          "TransformerDecoderLayer expects a three-dimensional input of "
          "shape [batch, sequence, width].")
    hidden_size = target_tensor_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    # Floor division keeps this an exact integer computation; divisibility was
    # verified just above.
    self.attention_head_size = int(hidden_size // self.num_attention_heads)
    # Keyword arguments shared by every attention/dense sub-layer built below.
    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Self attention.
    self.self_attention = attention.CachedAttention(
        num_heads=self.num_attention_heads,
        key_size=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        use_bias=self._use_bias,
        name="self_attention",
        **common_kwargs)
    # NOTE(review): this layer is never invoked in `call`, and it shares the
    # name "output" with `output_dense`. It is kept because it is a public
    # attribute -- confirm whether it can be removed.
    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        **common_kwargs)
    self.self_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.self_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon))
    # Encoder-decoder attention.
    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
        key_size=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        use_bias=self._use_bias,
        name="attention/encdec",
        **common_kwargs)
    self.encdec_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.encdec_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon))
    # Feed-forward projection.
    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
        name="intermediate",
        **common_kwargs)
    self.intermediate_activation_layer = tf.keras.layers.Activation(
        self.intermediate_activation)
    self.output_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        **common_kwargs)
    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
    super(TransformerDecoderLayer, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments merged into the base layer config."""
    config = {
        "num_attention_heads":
            self.num_attention_heads,
        "intermediate_size":
            self.intermediate_size,
        # `__init__` resolved the activation to a callable; serialize it so the
        # config holds an identifier that `from_config` can resolve again.
        "intermediate_activation":
            tf.keras.activations.serialize(self.intermediate_activation),
        "dropout_rate":
            self.dropout_rate,
        "attention_dropout_rate":
            self.attention_dropout_rate,
        "multi_channel_cross_attention":
            self.multi_channel_cross_attention,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias":
            self._use_bias,
        "norm_first":
            self._norm_first,
        "norm_epsilon":
            self._norm_epsilon
    }
    base_config = super(TransformerDecoderLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def common_layers_with_encoder(self):
    """Gets layer objects that can make a Transformer encoder block."""
    return [
        self.self_attention, self.self_attention_layer_norm,
        self.intermediate_dense, self.output_dense, self.output_layer_norm
    ]

  def call(self, inputs, cache=None, decode_loop_step=None):
    """Runs self-attention, cross-attention and the feed-forward network.

    Args:
      inputs: A sequence `[input_tensor, memory, attention_mask,
        self_attention_mask]`. When `multi_channel_cross_attention` is enabled
        a fifth element is required; it is forwarded to the cross-attention
        layer as `context_attention_weights`.
      cache: Optional value forwarded unchanged to the self-attention layer.
      decode_loop_step: Optional value forwarded unchanged to the
        self-attention layer.

    Returns:
      A tuple `(layer_output, cache)`, where `cache` is the second value
      returned by the self-attention layer.

    Raises:
      ValueError: If `inputs` does not have the expected number of elements.
    """
    if self.multi_channel_cross_attention:
      if len(inputs) != 5:
        raise ValueError(
            "TransformerDecoderLayer must have 5 inputs, when it uses "
            "multi_channel_cross_attention. But it got: %d" % len(inputs))
    elif len(inputs) != 4:
      raise ValueError(
          "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
          len(inputs))
    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]

    # Self-attention sub-layer. With `norm_first` the residual branch uses the
    # un-normalized input; otherwise the sum is normalized afterwards.
    source_tensor = input_tensor
    if self._norm_first:
      input_tensor = self.self_attention_layer_norm(input_tensor)
    self_attention_output, cache = self.self_attention(
        query=input_tensor,
        value=input_tensor,
        attention_mask=self_attention_mask,
        cache=cache,
        decode_loop_step=decode_loop_step)
    self_attention_output = self.self_attention_dropout(self_attention_output)
    if self._norm_first:
      self_attention_output = source_tensor + self_attention_output
    else:
      self_attention_output = self.self_attention_layer_norm(
          input_tensor + self_attention_output)

    # Encoder-decoder (cross) attention sub-layer.
    if self._norm_first:
      source_self_attention_output = self_attention_output
      self_attention_output = self.encdec_attention_layer_norm(
          self_attention_output)
    cross_attn_inputs = dict(
        query=self_attention_output,
        value=memory,
        attention_mask=attention_mask)
    if self.multi_channel_cross_attention:
      # Accesses the 5-th input tensor for the doc-attention probabilities.
      cross_attn_inputs["context_attention_weights"] = inputs[-1]
    attention_output = self.encdec_attention(**cross_attn_inputs)
    attention_output = self.encdec_attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_self_attention_output + attention_output
    else:
      attention_output = self.encdec_attention_layer_norm(
          self_attention_output + attention_output)

    # Feed-forward sub-layer.
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self.output_layer_norm(attention_output)
    intermediate_output = self.intermediate_dense(attention_output)
    intermediate_output = self.intermediate_activation_layer(
        intermediate_output)
    layer_output = self.output_dense(intermediate_output)
    layer_output = self.output_dropout(layer_output)
    if self._norm_first:
      layer_output = source_attention_output + layer_output
    else:
      layer_output = self.output_layer_norm(layer_output + attention_output)
    return layer_output, cache
official/nlp/modeling/layers/transformer_scaffold.py
View file @
0cceabfc
...
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
...
@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else
:
else
:
input_tensor
,
attention_mask
=
(
inputs
,
None
)
input_tensor
,
attention_mask
=
(
inputs
,
None
)
attention_inputs
=
[
input_tensor
,
input_tensor
]
attention_output
=
self
.
_attention_layer
(
query
=
input_tensor
,
value
=
input_tensor
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_layer
(
attention_inputs
,
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
self
.
_attention_layer_norm
(
input_tensor
+
attention_output
=
self
.
_attention_layer_norm
(
input_tensor
+
attention_output
)
attention_output
)
...
...
official/nlp/modeling/layers/transformer_scaffold_test.py
View file @
0cceabfc
...
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
...
@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super
(
ValidatedAttentionLayer
,
self
).
__init__
(
**
kwargs
)
super
(
ValidatedAttentionLayer
,
self
).
__init__
(
**
kwargs
)
self
.
list
=
call_list
self
.
list
=
call_list
def
call
(
self
,
inputs
,
attention_mask
=
None
):
def
call
(
self
,
query
,
value
,
attention_mask
=
None
):
self
.
list
.
append
(
True
)
self
.
list
.
append
(
True
)
return
super
(
ValidatedAttentionLayer
,
self
).
call
(
return
super
(
ValidatedAttentionLayer
,
self
).
call
(
inputs
,
attention_mask
=
attention_mask
)
query
,
value
,
attention_mask
=
attention_mask
)
def
get_config
(
self
):
def
get_config
(
self
):
config
=
super
(
ValidatedAttentionLayer
,
self
).
get_config
()
config
=
super
(
ValidatedAttentionLayer
,
self
).
get_config
()
...
...
official/nlp/modeling/layers/transformer_test.py
View file @
0cceabfc
...
@@ -152,7 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
...
@@ -152,7 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_
=
new_layer
([
input_data
,
mask_data
])
_
=
new_layer
([
input_data
,
mask_data
])
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.003
)
def
test_layer_invocation_with_float16_dtype
(
self
,
transformer_cls
):
def
test_layer_invocation_with_float16_dtype
(
self
,
transformer_cls
):
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'mixed_float16'
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'mixed_float16'
)
...
@@ -215,5 +216,113 @@ class TransformerLayerTest(keras_parameterized.TestCase):
...
@@ -215,5 +216,113 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self
.
assertAllEqual
([
1
,
input_length
,
width
],
output_data
.
shape
)
self
.
assertAllEqual
([
1
,
input_length
,
width
],
output_data
.
shape
)
@
keras_parameterized
.
run_all_keras_modes
class
TransformerArgumentTest
(
keras_parameterized
.
TestCase
):
def
test_use_bias_norm_first
(
self
):
num_attention_heads
=
2
hidden_size
=
16
encoder_block
=
transformer
.
Transformer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
# Forward path.
dummy_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
4
],
dtype
=
tf
.
float32
)
inputs
=
[
dummy_tensor
,
dummy_mask
]
output
=
encoder_block
(
inputs
)
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
def
test_get_config
(
self
):
num_attention_heads
=
2
encoder_block
=
transformer
.
Transformer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
encoder_block_config
=
encoder_block
.
get_config
()
new_encoder_block
=
transformer
.
Transformer
.
from_config
(
encoder_block_config
)
self
.
assertEqual
(
encoder_block_config
,
new_encoder_block
.
get_config
())
def
_create_cache
(
batch_size
,
init_decode_length
,
num_heads
,
head_size
):
return
{
'key'
:
tf
.
zeros
([
batch_size
,
init_decode_length
,
num_heads
,
head_size
],
dtype
=
tf
.
float32
),
'value'
:
tf
.
zeros
([
batch_size
,
init_decode_length
,
num_heads
,
head_size
],
dtype
=
tf
.
float32
)
}
@
keras_parameterized
.
run_all_keras_modes
class
TransformerDecoderLayerTest
(
keras_parameterized
.
TestCase
):
def
test_decoder_block_with_cache
(
self
):
num_attention_heads
=
2
hidden_size
=
16
decoder_block
=
transformer
.
TransformerDecoderLayer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
)
# Forward path.
dummy_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
4
],
dtype
=
tf
.
float32
)
inputs
=
[
dummy_tensor
,
dummy_tensor
,
dummy_mask
,
dummy_mask
]
cache
=
_create_cache
(
2
,
0
,
num_attention_heads
,
hidden_size
//
num_attention_heads
)
output
,
cache
=
decoder_block
(
inputs
,
cache
)
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
self
.
assertEqual
(
cache
[
'value'
].
shape
,
(
2
,
4
,
2
,
8
))
def
test_use_bias_norm_first
(
self
):
num_attention_heads
=
2
hidden_size
=
16
decoder_block
=
transformer
.
TransformerDecoderLayer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
# Forward path.
dummy_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
4
],
dtype
=
tf
.
float32
)
inputs
=
[
dummy_tensor
,
dummy_tensor
,
dummy_mask
,
dummy_mask
]
output
,
_
=
decoder_block
(
inputs
)
self
.
assertEqual
(
output
.
shape
,
(
2
,
4
,
hidden_size
))
def
test_get_config
(
self
):
num_attention_heads
=
2
decoder_block
=
transformer
.
TransformerDecoderLayer
(
num_attention_heads
=
num_attention_heads
,
intermediate_size
=
32
,
intermediate_activation
=
'relu'
,
dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
use_bias
=
False
,
norm_first
=
True
,
norm_epsilon
=
1e-6
)
decoder_block_config
=
decoder_block
.
get_config
()
new_decoder_block
=
transformer
.
TransformerDecoderLayer
.
from_config
(
decoder_block_config
)
self
.
assertEqual
(
decoder_block_config
,
new_decoder_block
.
get_config
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/nlp/modeling/losses/README.md
View file @
0cceabfc
...
@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
...
@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
*
`weighted_sparse_categorical_crossentropy_loss`
computes per-batch sparse
*
`weighted_sparse_categorical_crossentropy_loss`
computes per-batch sparse
categorical crossentropy loss.
categorical crossentropy loss.
*
`weighted_sparse_categorical_crossentropy_per_example_loss`
computes
per-example sparse categorical crossentropy loss.
official/nlp/modeling/losses/__init__.py
View file @
0cceabfc
...
@@ -14,4 +14,3 @@
...
@@ -14,4 +14,3 @@
# ==============================================================================
# ==============================================================================
"""Activations package definition. Subject to change."""
"""Activations package definition. Subject to change."""
from
official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy
import
loss
as
weighted_sparse_categorical_crossentropy_loss
from
official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy
import
loss
as
weighted_sparse_categorical_crossentropy_loss
from
official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy
import
per_example_loss
as
weighted_sparse_categorical_crossentropy_per_example_loss
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy.py
View file @
0cceabfc
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
S
parse categorical cross-entropy losses."""
"""
Weighted s
parse categorical cross-entropy losses."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
...
@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s."
)
%
(
labels
.
shape
,
predictions
.
shape
))
"predictions.shape was %s."
)
%
(
labels
.
shape
,
predictions
.
shape
))
def
per_example_loss
(
labels
,
predictions
,
weights
=
None
):
def
loss
(
labels
,
predictions
,
weights
=
None
,
from_logits
=
False
):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels
,
predictions
=
_adjust_labels
(
labels
,
predictions
)
_validate_rank
(
labels
,
predictions
,
weights
)
labels_one_hot
=
tf
.
one_hot
(
labels
,
predictions
.
shape
[
-
1
])
labels_one_hot
=
tf
.
cast
(
labels_one_hot
,
predictions
.
dtype
)
per_example_loss_data
=
-
tf
.
reduce_sum
(
predictions
*
labels_one_hot
,
axis
=
[
-
1
])
if
weights
is
not
None
:
weights
=
tf
.
cast
(
weights
,
per_example_loss_data
.
dtype
)
per_example_loss_data
=
weights
*
per_example_loss_data
return
per_example_loss_data
def
loss
(
labels
,
predictions
,
weights
=
None
):
"""Calculate a per-batch sparse categorical crossentropy loss.
"""Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
This loss function assumes that the predictions are post-softmax.
...
@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
...
@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied.
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
If None, all examples will be used.
from_logits: Whether the input predictions are logits.
Returns:
Returns:
A loss scalar.
A loss scalar.
...
@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
...
@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
labels
,
predictions
=
_adjust_labels
(
labels
,
predictions
)
labels
,
predictions
=
_adjust_labels
(
labels
,
predictions
)
_validate_rank
(
labels
,
predictions
,
weights
)
_validate_rank
(
labels
,
predictions
,
weights
)
per_example_loss_data
=
per_example_loss
(
labels
,
predictions
,
weights
)
example_losses
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
predictions
,
from_logits
=
from_logits
)
if
weights
is
None
:
if
weights
is
None
:
return
tf
.
reduce_mean
(
per_example_loss_data
)
return
tf
.
reduce_mean
(
example_losses
)
else
:
weights
=
tf
.
cast
(
weights
,
predictions
.
dtype
)
numerator
=
tf
.
reduce_sum
(
per_example_loss_data
)
return
tf
.
math
.
divide_no_nan
(
weights
=
tf
.
cast
(
weights
,
predictions
.
dtype
)
tf
.
reduce_sum
(
example_losses
*
weights
),
tf
.
reduce_sum
(
weights
))
denominator
=
tf
.
reduce_sum
(
weights
)
+
1e-5
return
numerator
/
denominator
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py
View file @
0cceabfc
...
@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack.
# Create a maskedLM from the transformer stack.
test_layer
=
layers
.
MaskedLM
(
test_layer
=
layers
.
MaskedLM
(
embedding_table
=
xformer_stack
.
get_embedding_table
(),
embedding_table
=
xformer_stack
.
get_embedding_table
(),
output
=
output
)
output
=
output
)
# Create a model from the masked LM layer.
# Create a model from the masked LM layer.
lm_input_tensor
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,
hidden_size
))
lm_input_tensor
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,
hidden_size
))
...
@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output
=
test_layer
(
lm_input_tensor
,
masked_positions
=
masked_lm_positions
)
output
=
test_layer
(
lm_input_tensor
,
masked_positions
=
masked_lm_positions
)
return
tf
.
keras
.
Model
([
lm_input_tensor
,
masked_lm_positions
],
output
)
return
tf
.
keras
.
Model
([
lm_input_tensor
,
masked_lm_positions
],
output
)
def
create_classification_model
(
self
,
input_width
,
num_classes
):
test_object
=
networks
.
Classification
(
input_width
=
input_width
,
num_classes
=
num_classes
)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data
=
tf
.
keras
.
Input
(
shape
=
(
input_width
,),
dtype
=
tf
.
float32
)
output
=
test_object
(
pooled_data
)
return
tf
.
keras
.
Model
(
pooled_data
,
output
)
def
test_per_example_loss_3d_input
(
self
):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size
=
100
sequence_length
=
32
hidden_size
=
64
num_predictions
=
21
model
=
self
.
create_lm_model
(
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
num_predictions
=
num_predictions
)
# Get the output of the masked LM.
batch_size
=
3
lm_input_data
=
10
*
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
hidden_size
))
masked_position_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
num_predictions
))
output_data
=
model
.
predict
([
lm_input_data
,
masked_position_data
])
# Calculate per-example loss.
labels
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
num_predictions
))
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape
=
[
batch_size
,
num_predictions
]
self
.
assertEqual
(
expected_shape
,
per_example_loss_data
.
shape
.
as_list
())
self
.
assertNotAllClose
(
tf
.
zeros_like
(
per_example_loss_data
),
per_example_loss_data
)
def
test_per_example_loss_2d_input
(
self
):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width
=
512
num_classes
=
10
model
=
self
.
create_classification_model
(
input_width
,
num_classes
)
# Invoke the network as part of a Model.
batch_size
=
3
input_data
=
10
*
np
.
random
.
random_sample
((
batch_size
,
input_width
))
output_data
=
model
.
predict
(
input_data
)
# Calculate per example loss.
labels
=
np
.
random
.
randint
(
num_classes
,
size
=
(
batch_size
))
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self
.
assertEqual
([
batch_size
],
per_example_loss_data
.
shape
.
as_list
())
self
.
assertNotAllClose
(
tf
.
zeros_like
(
per_example_loss_data
),
per_example_loss_data
)
def
test_per_example_loss_weights_3d_input
(
self
):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size
=
100
sequence_length
=
32
hidden_size
=
64
num_predictions
=
21
model
=
self
.
create_lm_model
(
vocab_size
=
vocab_size
,
sequence_length
=
sequence_length
,
hidden_size
=
hidden_size
,
num_predictions
=
num_predictions
)
# Get the output of the masked LM.
batch_size
=
3
lm_input_data
=
10
*
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
hidden_size
))
masked_position_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
num_predictions
))
output_data
=
model
.
predict
([
lm_input_data
,
masked_position_data
])
# Calculate per-example loss with weights.
labels
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
num_predictions
))
weights
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
num_predictions
))
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss
=
per_example_loss_data
*
weights
self
.
assertAllClose
(
expected_weighted_loss
,
per_example_loss_data
)
def
test_per_example_loss_weights_2d_input
(
self
):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width
=
512
num_classes
=
10
model
=
self
.
create_classification_model
(
input_width
,
num_classes
)
# Invoke the network as part of a Model.
batch_size
=
3
input_data
=
10
*
np
.
random
.
random_sample
((
batch_size
,
input_width
))
output_data
=
model
.
predict
(
input_data
)
# Calculate per-example loss with weights.
labels
=
np
.
random
.
randint
(
num_classes
,
size
=
(
batch_size
))
weights
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
))
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss
=
per_example_loss_data
*
weights
self
.
assertAllClose
(
expected_weighted_loss
,
per_example_loss_data
)
def
test_loss_3d_input
(
self
):
def
test_loss_3d_input
(
self
):
"""Test overall loss with a 3-dimensional input, from a masked LM."""
"""Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size
=
100
vocab_size
=
100
...
@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self
.
assertNotAllClose
(
self
.
assertNotAllClose
(
tf
.
zeros_like
(
per_example_loss_data
),
per_example_loss_data
)
tf
.
zeros_like
(
per_example_loss_data
),
per_example_loss_data
)
def
test_loss_2d_input
(
self
):
"""Test overall loss with a 2-d input, from a classifier."""
input_width
=
512
num_classes
=
10
model
=
self
.
create_classification_model
(
input_width
,
num_classes
)
# Invoke the network as part of a Model.
batch_size
=
3
input_data
=
10
*
np
.
random
.
random_sample
((
batch_size
,
input_width
))
output_data
=
model
.
predict
(
input_data
)
# Calculate per example loss.
labels
=
np
.
random
.
randint
(
num_classes
,
size
=
(
batch_size
))
loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self
.
assertNotAllClose
(
0
,
loss_data
)
def
test_loss_weights_3d_input
(
self
):
def
test_loss_weights_3d_input
(
self
):
"""Test masked loss with a 3-dimensional input, from a masked LM."""
"""Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size
=
100
vocab_size
=
100
...
@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0.
# Because the tensor is fully masked, the loss should be 0.
self
.
assertAllClose
(
0
,
weighted_loss_data
)
self
.
assertAllClose
(
0
,
weighted_loss_data
)
def
test_loss_weights_2d_input
(
self
):
"""Test masked loss with a 2-d input, from a classifier."""
input_width
=
512
num_classes
=
10
model
=
self
.
create_classification_model
(
input_width
,
num_classes
)
# Invoke the network as part of a Model.
batch_size
=
3
input_data
=
10
*
np
.
random
.
random_sample
((
batch_size
,
input_width
))
output_data
=
model
.
predict
(
input_data
)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels
=
np
.
random
.
randint
(
num_classes
,
size
=
(
batch_size
))
null_weights
=
np
.
zeros
((
batch_size
))
weighted_loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
null_weights
)
# Because the tensor is fully masked, the loss should be 0.
self
.
assertAllClose
(
0
,
weighted_loss_data
)
def
test_mismatched_predictions_and_labels_ranks_squeezes
(
self
):
def
test_mismatched_predictions_and_labels_ranks_squeezes
(
self
):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
"""Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size
=
3
batch_size
=
3
...
@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels
=
np
.
random
.
randint
(
10
,
size
=
(
batch_size
,
1
))
labels
=
np
.
random
.
randint
(
10
,
size
=
(
batch_size
,
1
))
# All that this test tests is that the squeeze is successful.
# All that this test tests is that the squeeze is successful.
_
=
weighted_sparse_categorical_crossentropy
.
per_example_
loss
(
_
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
)
predictions
=
output_data
,
labels
=
labels
)
def
test_mismatched_weights_and_labels_ranks_fail
(
self
):
def
test_mismatched_weights_and_labels_ranks_fail
(
self
):
...
@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels
=
np
.
random
.
randint
(
10
,
size
=
(
batch_size
,
10
))
labels
=
np
.
random
.
randint
(
10
,
size
=
(
batch_size
,
10
))
weights
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
))
weights
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
))
with
self
.
assertRaisesRegex
(
RuntimeError
,
".*of the same rank.*"
):
_
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
".*of the same rank.*"
):
with
self
.
assertRaisesRegex
(
RuntimeError
,
".*of the same rank.*"
):
_
=
weighted_sparse_categorical_crossentropy
.
loss
(
_
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
...
@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that
# We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime
# we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code.
# errors from the shape checking code.
_
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
_
=
weighted_sparse_categorical_crossentropy
.
loss
(
_
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
...
@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[
-
2.7760355
,
-
1.8219438
,
-
3.0924666
,
-
1.0779881
,
-
0.9407509
]]])
[
-
2.7760355
,
-
1.8219438
,
-
3.0924666
,
-
1.0779881
,
-
0.9407509
]]])
labels
=
np
.
array
([[
4
,
0
],
[
2
,
2
],
[
2
,
1
]])
labels
=
np
.
array
([[
4
,
0
],
[
2
,
2
],
[
2
,
1
]])
# Validate that per_example loss calculations are the same.
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
)
expected_per_example_loss_data
=
[[
1.2923571
,
2.7117882
],
[
2.287932
,
2.287932
],
[
3.0924666
,
1.8219438
]]
self
.
assertAllClose
(
expected_per_example_loss_data
,
per_example_loss_data
)
# Validate that overall loss calculations are the same.
# Validate that overall loss calculations are the same.
weights
=
np
.
array
([[
1
,
0
],
[
0
,
0
],
[
0
,
0
]])
weights
=
np
.
array
([[
1
,
0
],
[
0
,
0
],
[
0
,
0
]])
loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
,
from_logits
=
True
)
expected_loss_data
=
1.2923441
expected_loss_data
=
1.2923441
self
.
assertAllClose
(
expected_loss_data
,
loss_data
)
self
.
assertAllClose
(
expected_loss_data
,
loss_data
,
rtol
=
1e-3
)
def
test_legacy_classification_loss_compatibility
(
self
):
def
test_legacy_classification_loss_compatibility
(
self
):
"""Test to validate computational correctness during refactors."""
"""Test to validate computational correctness during refactors."""
...
@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
...
@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[
-
1.6975292e-03
,
-
6.4009643e+00
,
-
1.0226612e+01
]])
[
-
1.6975292e-03
,
-
6.4009643e+00
,
-
1.0226612e+01
]])
labels
=
np
.
array
([
2
,
1
])
labels
=
np
.
array
([
2
,
1
])
# Validate that per_example loss calculations are the same.
per_example_loss_data
=
weighted_sparse_categorical_crossentropy
.
per_example_loss
(
predictions
=
output_data
,
labels
=
labels
)
expected_per_example_loss_data
=
[
6.4434357
,
6.4009643
]
self
.
assertAllClose
(
expected_per_example_loss_data
,
per_example_loss_data
)
# Validate that overall loss calculations are the same.
# Validate that overall loss calculations are the same.
weights
=
None
weights
=
None
loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
loss_data
=
weighted_sparse_categorical_crossentropy
.
loss
(
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
)
predictions
=
output_data
,
labels
=
labels
,
weights
=
weights
,
from_logits
=
True
)
expected_loss_data
=
6.4222
expected_loss_data
=
6.4222
self
.
assertAllClose
(
expected_loss_data
,
loss_data
)
self
.
assertAllClose
(
expected_loss_data
,
loss_data
,
rtol
=
1e-3
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/nlp/modeling/models/README.md
View file @
0cceabfc
...
@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
...
@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
It can be used as a regression model as well.
It can be used as a regression model as well.
*
[
`BertTokenClassifier`
](
bert_token_classifier.py
)
implements a simple token
*
[
`BertTokenClassifier`
](
bert_token_classifier.py
)
implements a simple token
classification model containing a single classification head
using th
e
classification model containing a single classification head
over the sequenc
e
TokenClassification network
.
output embeddings
.
*
[
`BertSpanLabeler`
](
bert_span_labeler.py
)
implementats a simple single-span
*
[
`BertSpanLabeler`
](
bert_span_labeler.py
)
implementats a simple single-span
start-end predictor (that is, a model that predicts two values: a start token
start-end predictor (that is, a model that predicts two values: a start token
...
...
official/nlp/modeling/models/__init__.py
View file @
0cceabfc
...
@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
...
@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from
official.nlp.modeling.models.bert_pretrainer
import
BertPretrainer
from
official.nlp.modeling.models.bert_pretrainer
import
BertPretrainer
from
official.nlp.modeling.models.bert_span_labeler
import
BertSpanLabeler
from
official.nlp.modeling.models.bert_span_labeler
import
BertSpanLabeler
from
official.nlp.modeling.models.bert_token_classifier
import
BertTokenClassifier
from
official.nlp.modeling.models.bert_token_classifier
import
BertTokenClassifier
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
official/nlp/modeling/models/bert_classifier.py
View file @
0cceabfc
...
@@ -12,15 +12,12 @@
...
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""
Trainer network for BERT-style models
."""
"""
BERT cls-token classifier
."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
from
official.nlp.modeling
import
networks
...
@@ -36,6 +33,9 @@ class BertClassifier(tf.keras.Model):
...
@@ -36,6 +33,9 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes`
instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated.
argument. If `num_classes` is set to 1, a regression network is instantiated.
*Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
Arguments:
network: A transformer network. This network should output a sequence output
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
and a classification output. Furthermore, it should expose its embedding
...
@@ -43,23 +43,25 @@ class BertClassifier(tf.keras.Model):
...
@@ -43,23 +43,25 @@ class BertClassifier(tf.keras.Model):
num_classes: Number of classes to predict from the classification network.
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
dropout_rate: The dropout probability of the cls head.
'predictions'.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside
the encoder.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
network
,
network
,
num_classes
,
num_classes
,
initializer
=
'glorot_uniform'
,
initializer
=
'glorot_uniform'
,
output
=
'logits'
,
dropout_rate
=
0.1
,
dropout_rate
=
0.1
,
use_encoder_pooler
=
True
,
**
kwargs
):
**
kwargs
):
self
.
_self_setattr_tracking
=
False
self
.
_self_setattr_tracking
=
False
self
.
_network
=
network
self
.
_config
=
{
self
.
_config
=
{
'network'
:
network
,
'network'
:
network
,
'num_classes'
:
num_classes
,
'num_classes'
:
num_classes
,
'initializer'
:
initializer
,
'initializer'
:
initializer
,
'
output'
:
output
,
'
use_encoder_pooler'
:
use_encoder_pooler
,
}
}
# We want to use the inputs of the passed network as the inputs to this
# We want to use the inputs of the passed network as the inputs to this
...
@@ -67,22 +69,36 @@ class BertClassifier(tf.keras.Model):
...
@@ -67,22 +69,36 @@ class BertClassifier(tf.keras.Model):
# when we construct the Model object at the end of init.
# when we construct the Model object at the end of init.
inputs
=
network
.
inputs
inputs
=
network
.
inputs
# Because we have a copy of inputs to create this Model object, we can
if
use_encoder_pooler
:
# invoke the Network object with its own input tensors to start the Model.
# Because we have a copy of inputs to create this Model object, we can
_
,
cls_output
=
network
(
inputs
)
# invoke the Network object with its own input tensors to start the Model.
cls_output
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
dropout_rate
)(
cls_output
)
_
,
cls_output
=
network
(
inputs
)
cls_output
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
dropout_rate
)(
cls_output
)
self
.
classifier
=
networks
.
Classification
(
self
.
classifier
=
networks
.
Classification
(
input_width
=
cls_output
.
shape
[
-
1
],
input_width
=
cls_output
.
shape
[
-
1
],
num_classes
=
num_classes
,
num_classes
=
num_classes
,
initializer
=
initializer
,
initializer
=
initializer
,
output
=
output
,
output
=
'logits'
,
name
=
'classification'
)
name
=
'sentence_prediction'
)
predictions
=
self
.
classifier
(
cls_output
)
predictions
=
self
.
classifier
(
cls_output
)
else
:
sequence_output
,
_
=
network
(
inputs
)
self
.
classifier
=
layers
.
ClassificationHead
(
inner_dim
=
sequence_output
.
shape
[
-
1
],
num_classes
=
num_classes
,
initializer
=
initializer
,
dropout_rate
=
dropout_rate
,
name
=
'sentence_prediction'
)
predictions
=
self
.
classifier
(
sequence_output
)
super
(
BertClassifier
,
self
).
__init__
(
super
(
BertClassifier
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
predictions
,
**
kwargs
)
inputs
=
inputs
,
outputs
=
predictions
,
**
kwargs
)
@
property
def
checkpoint_items
(
self
):
return
dict
(
encoder
=
self
.
_network
)
def
get_config
(
self
):
def
get_config
(
self
):
return
self
.
_config
return
self
.
_config
...
...
Prev
1
2
3
4
5
6
7
8
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment