ModelZoo / ResNet50_tensorflow · Commits · 5a2cf36f

Commit 5a2cf36f, authored Jul 23, 2020 by Kaushik Shivakumar
Merge remote-tracking branch 'upstream/master' into newavarecords
Parents: 258ddfc3, a829e648
Changes: 330. Showing 20 changed files with 246 additions and 221 deletions.
official/nlp/modeling/layers/dense_einsum.py                   +5  -0
official/nlp/modeling/layers/multi_channel_attention.py        +46 -31
official/nlp/modeling/layers/multi_channel_attention_test.py   +5  -1
official/nlp/modeling/layers/position_embedding.py             +0  -1
official/nlp/modeling/layers/rezero_transformer.py             +22 -31
official/nlp/modeling/layers/talking_heads_attention.py        +7  -7
official/nlp/modeling/layers/talking_heads_attention_test.py   +10 -8
official/nlp/modeling/layers/transformer.py                    +62 -85
official/nlp/modeling/layers/transformer_scaffold.py           +2  -3
official/nlp/modeling/layers/transformer_scaffold_test.py      +2  -2
official/nlp/modeling/layers/transformer_test.py               +4  -1
official/nlp/modeling/models/bert_classifier.py                +35 -15
official/nlp/modeling/models/bert_classifier_test.py           +2  -3
official/nlp/modeling/models/bert_pretrainer.py                +15 -19
official/nlp/modeling/models/bert_pretrainer_test.py           +3  -4
official/nlp/modeling/models/bert_span_labeler.py              +4  -1
official/nlp/modeling/models/bert_token_classifier.py          +3  -0
official/nlp/modeling/models/electra_pretrainer.py             +17 -6
official/nlp/modeling/models/electra_pretrainer_test.py        +0  -3
official/nlp/modeling/networks/albert_transformer_encoder.py   +2  -0
official/nlp/modeling/layers/dense_einsum.py

@@ -21,6 +21,8 @@ from __future__ import print_function
 import tensorflow as tf

+from tensorflow.python.util import deprecation
+
 _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]

@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
       `(batch_size, units)`.
   """

+  @deprecation.deprecated(None, "DenseEinsum is deprecated. Please use "
+                          "tf.keras.experimental.EinsumDense layer instead.")
   def __init__(self,
                output_shape,
                num_summed_dimensions=1,
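For reference, the replacement that the new deprecation message points at looks roughly like the sketch below. This is not part of the commit; it assumes a TensorFlow release that ships tf.keras.layers.experimental.EinsumDense (2.3+), and the shapes are illustrative.

import tensorflow as tf

num_heads, head_size = 8, 64

# Old style: dense_einsum.DenseEinsum(output_shape=(num_heads, head_size))
# New style: spell out the einsum equation and the bias axes explicitly.
projection = tf.keras.layers.experimental.EinsumDense(
    "abc,cde->abde",                            # [batch, seq, hidden] x [hidden, N, H]
    output_shape=(None, num_heads, head_size),  # None keeps the sequence axis dynamic
    bias_axes="de")                             # one bias per (head, head_size) entry

x = tf.random.uniform((2, 16, 512))             # [batch, seq, hidden]
print(projection(x).shape)                      # (2, 16, 8, 64)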
official/nlp/modeling/layers/multi_channel_attention.py

@@ -26,7 +26,6 @@ import math
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import masked_softmax

@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)

   def build(self, unused_input_shapes):
-    self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_query")
-    self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_key")
+        bias_constraint=self._bias_constraint)
+    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="query",
+        **common_kwargs)
+    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="key",
+        **common_kwargs)
     super(VotingAttention, self).build(unused_input_shapes)

   def call(self, encoder_outputs, doc_attention_mask):

@@ -113,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
 class MultiChannelAttention(attention.MultiHeadAttention):
   """Multi-channel Attention layer.

-  Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
-  cross-attention target sequences.
+  Introduced in, [Generating Representative Headlines for News Stories
+  ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
+  target sequences.
+
+  Call args:
+    query: Query `Tensor` of shape `[B, T, dim]`.
+    value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
+      number of documents.
+    context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
+      is the number of attention heads. Combines multi-channel sources
+      context tensors according to the distribution among channels.
+    key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
+      `value` for both `key` and `value`, which is the most common case.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+      to certain positions.
   """

-  def _build_attention(self, qkv_rank):
-    super(MultiChannelAttention, self)._build_attention(qkv_rank)
+  def build_attention(self, rank):
+    super(MultiChannelAttention, self).build_attention(rank)
     self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])

-  def call(self, inputs, attention_mask=None):
-    from_tensor = inputs[0]
-    to_tensor = inputs[1]
-    doc_attention_probs = inputs[2]
+  def call(self,
+           query,
+           value,
+           key=None,
+           context_attention_weights=None,
+           attention_mask=None):
+    if not self._built_from_signature:
+      self._build_from_signature(query, value, key=key)
+    if key is None:
+      key = value

     # Scalar dimensions referenced here:
     # B = batch size (number of stories)
     # A = num_docs (number of docs)
-    # F = `from_tensor` sequence length
-    # T = `to_tensor` sequence length
+    # F = target sequence length
+    # T = source sequence length
     # N = `num_attention_heads`
     # H = `size_per_head`
     # `query_tensor` = [B, F, N ,H]
-    query_tensor = self._query_dense(from_tensor)
+    query_tensor = self._query_dense(query)
     # `key_tensor` = [B, A, T, N, H]
-    key_tensor = self._key_dense(to_tensor)
+    key_tensor = self._key_dense(key)
     # `value_tensor` = [B, A, T, N, H]
-    value_tensor = self._value_dense(to_tensor)
+    value_tensor = self._value_dense(value)

     # Take the dot product between "query" and "key" to get the raw
     # attention scores.

@@ -159,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
     # `context_layer` = [B, F, N, H]
     context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
                               value_tensor)
-    attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
+    attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
                                  context_layer)
     attention_output = self._output_dense(attention_output)
     return attention_output
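The new keyword-based call signature can be exercised roughly as below. This is a hedged sketch based on the updated docstring and the test change in this commit; it assumes the Model Garden package (official.nlp.modeling.layers) from this revision is importable, and the tensor shapes follow the Call args above.

import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import multi_channel_attention

num_heads, num_docs = 2, 5
layer = multi_channel_attention.MultiChannelAttention(
    num_heads=num_heads, key_size=4)

query = np.random.rand(3, 4, 8).astype(np.float32)            # [B, T, dim]
value = np.random.rand(3, num_docs, 2, 8).astype(np.float32)  # [B, A, S, dim]
weights = np.random.rand(3, num_heads, 4, num_docs).astype(np.float32)  # [B, N, T, A]
mask = np.random.randint(2, size=(3, num_docs, 4, 2))          # [B, A, T, S]

# Old call: layer([query, value, weights], mask)
out = layer(
    query=query,
    value=value,
    context_attention_weights=weights,
    attention_mask=mask)
print(out.shape)  # (3, 4, 8)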
official/nlp/modeling/layers/multi_channel_attention_test.py

@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
     mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
     doc_probs = np.random.randint(
         2, size=(3, num_heads, 4, num_docs)).astype(float)
-    outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
+    outputs = attention_layer(
+        query=from_data,
+        value=to_data,
+        context_attention_weights=doc_probs,
+        attention_mask=mask_data)
     self.assertEqual(outputs.shape, (3, 4, 8))
official/nlp/modeling/layers/position_embedding.py

@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         "hidden_size": self._hidden_size,
         "min_timescale": self._min_timescale,
         "max_timescale": self._max_timescale,
-        "length": self._length,
     }
     base_config = super(RelativePositionEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
official/nlp/modeling/layers/rezero_transformer.py

@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum


 @tf.keras.utils.register_keras_serializable(package="Text")

@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
+    common_kwargs = dict(
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint)
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        name="self_attention",
+        **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.

@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
             axis=-1, epsilon=1e-12, dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge

@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.

@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = target_tensor + self._rezero_a * attention_output
     if self._use_layer_norm:
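The pattern above (collect common_kwargs once and fan it out to each EinsumDense) can be reproduced with plain tf.keras as a quick sanity check. The layer names and sizes below are illustrative, not taken from the commit.

import tensorflow as tf

hidden_size, intermediate_size = 64, 256
common_kwargs = dict(
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros")

# Feed-forward pair, matching the "abc,cd->abd" equation used in the diff.
intermediate_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, intermediate_size),
    bias_axes="d",
    name="intermediate",
    **common_kwargs)
output_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, hidden_size),
    bias_axes="d",
    name="output",
    **common_kwargs)

x = tf.random.uniform((2, 10, hidden_size))
y = output_dense(tf.nn.gelu(intermediate_dense(x)))
print(y.shape)  # (2, 10, 64)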
official/nlp/modeling/layers/talking_heads_attention.py

@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     bias_constraint: Constraint for dense layer kernels.
   """

-  def _build_attention(self, qkv_rank):
+  def build_attention(self, qkv_rank):
     """Builds multi-head dot-product attention computations.

     This function overrides base class to create additional linear projection

@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
     Args:
       qkv_rank: the rank of query, key, value tensors after projection.
     """
-    super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
+    super(TalkingHeadsAttention, self).build_attention(qkv_rank)

     # Build an equation:
     # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->

@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
         dtype=self.dtype,
         trainable=True)

-  def _compute_attention(self,
-                         query_tensor,
-                         key_tensor,
-                         value_tensor,
-                         attention_mask=None):
+  def compute_attention(self,
+                        query_tensor,
+                        key_tensor,
+                        value_tensor,
+                        attention_mask=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function overrides base class to apply additional linear projection
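With the leading underscore dropped, subclasses now override build_attention and compute_attention, as TalkingHeadsAttention does above. A minimal, hypothetical subclass under the assumption that the Model Garden MultiHeadAttention at this revision exposes those hooks:

from official.nlp.modeling.layers import attention


class ExtraProjectionAttention(attention.MultiHeadAttention):
  """Illustrative subclass only; mirrors how TalkingHeadsAttention hooks in."""

  def build_attention(self, rank):
    # Let the base class build the standard dot-product attention first,
    # then any extra weights (e.g. head-mixing projections) would go here.
    super(ExtraProjectionAttention, self).build_attention(rank)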
official/nlp/modeling/layers/talking_heads_attention_test.py

@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
     value = tf.keras.Input(shape=(20, 80))
-    output = test_layer([query, value])
+    output = test_layer(query=query, value=value)
     self.assertEqual(output.shape.as_list(), [None] + output_dims)

   def test_non_masked_self_attention(self):

@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   def test_attention_scores(self):

@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         num_heads=12, key_size=64, return_attention_scores=True)
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output, coef = test_layer([query, query])
+    output, coef = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])
     self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])

@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     query = tf.keras.Input(shape=(4, 8))
     value = tf.keras.Input(shape=(2, 8))
     mask_tensor = tf.keras.Input(shape=(4, 2))
-    output = test_layer([query, value], mask_tensor)
+    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

     # Create a model containing the test layer.
     model = tf.keras.Model([query, value, mask_tensor], output)

@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Tests the layer with three inputs: Q, K, V.
     key = tf.keras.Input(shape=(2, 8))
-    output = test_layer([query, value, key], mask_tensor)
+    output = test_layer(
+        query=query, value=value, key=key, attention_mask=mask_tensor)
     model = tf.keras.Model([query, value, key, mask_tensor], output)

     masked_output_data = model.predict([from_data, to_data, to_data, mask_data])

@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
         kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
     # Create a 3-dimensional input (the first dimension is implicit).
     query = tf.keras.Input(shape=(40, 80))
-    output = test_layer([query, query])
+    output = test_layer(query=query, value=query)
     self.assertEqual(output.shape.as_list(), [None, 40, 80])

   @parameterized.named_parameters(

@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
     # Invoke the data with a random set of mask data. This should mask at least
     # one element.
     mask_data = np.random.randint(2, size=mask_shape).astype("bool")
-    output = test_layer([query, value], mask_data)
+    output = test_layer(query=query, value=value, attention_mask=mask_data)

     # Invoke the same data, but with a null mask (where no elements are masked).
     null_mask_data = np.ones(mask_shape)
-    unmasked_output = test_layer([query, value], null_mask_data)
+    unmasked_output = test_layer(
+        query=query, value=value, attention_mask=null_mask_data)
     # Because one data is masked and one is not, the outputs should not be the
     # same.
     self.assertNotAllClose(output, unmasked_output)
official/nlp/modeling/layers/transformer.py

@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers.util import tf_function_if_eager

@@ -106,21 +105,24 @@ class Transformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
+    common_kwargs = dict(
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint)
     self._attention_layer = attention.MultiHeadAttention(
         num_heads=self._num_heads,
         key_size=self._attention_head_size,
         dropout=self._attention_dropout_rate,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        name="self_attention",
+        **common_kwargs)
     # pylint: disable=protected-access
-    self._attention_layer.build([input_tensor_shape] * 3)
+    # Temporarily handling for checkpoint compatible changes.
+    self._attention_layer._build_from_signature(
+        query=input_tensor_shape, value=input_tensor_shape)
     self._attention_output_dense = self._attention_layer._output_dense
     # pylint: enable=protected-access
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

@@ -132,17 +134,12 @@ class Transformer(tf.keras.layers.Layer):
             axis=-1, epsilon=1e-12, dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge

@@ -151,16 +148,12 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(

@@ -211,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
       attention_mask = attention_mask[:, 0:self._output_range, :]
     else:
       target_tensor = input_tensor
-    attention_inputs = [target_tensor, input_tensor]

-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=target_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(target_tensor +
                                                   attention_output)

@@ -312,30 +305,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "The hidden size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size / self.num_attention_heads)
+    common_kwargs = dict(
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint)
     # Self attention.
     self.self_attention = attention.CachedAttention(
         num_heads=self.num_attention_heads,
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
-    self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
+        name="self_attention",
+        **common_kwargs)
+    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (

@@ -347,14 +337,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="attention/encdec")
+        name="attention/encdec",
+        **common_kwargs)
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)

@@ -363,29 +347,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
             name="attention/encdec_output_layer_norm",
             axis=-1,
             epsilon=1e-12))
     # Feed-forward projection.
-    self.intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self.intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self.intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
-    self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)

@@ -409,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
           len(inputs))
     input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
-    self_attention_inputs = [input_tensor, input_tensor]
     self_attention_output, cache = self.self_attention(
-        self_attention_inputs,
+        query=input_tensor,
+        value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
         decode_loop_step=decode_loop_step)
     self_attention_output = self.self_attention_dropout(self_attention_output)
     self_attention_output = self.self_attention_layer_norm(
         input_tensor + self_attention_output)
-    cross_attn_inputs = [self_attention_output, memory]
+    cross_attn_inputs = dict(
+        query=self_attention_output,
+        value=memory,
+        attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs.append(inputs[-1])
-    attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
+    attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     attention_output = self.encdec_attention_layer_norm(
         self_attention_output + attention_output)
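The call-site change applied throughout this file, keyword arguments instead of a list of inputs, looks like this in isolation. A hedged sketch: it assumes the Model Garden attention.MultiHeadAttention at this revision, and the sizes are illustrative.

import tensorflow as tf
from official.nlp.modeling.layers import attention

layer = attention.MultiHeadAttention(num_heads=4, key_size=16)

x = tf.random.uniform((2, 8, 64))   # [batch, target_len, hidden]
mask = tf.ones((2, 8, 8))           # [batch, target_len, source_len]

# Old: layer([x, x], mask)
out = layer(query=x, value=x, attention_mask=mask)
print(out.shape)  # (2, 8, 64)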
official/nlp/modeling/layers/transformer_scaffold.py

@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
     else:
       input_tensor, attention_mask = (inputs, None)

-    attention_inputs = [input_tensor, input_tensor]
-
-    attention_output = self._attention_layer(attention_inputs, attention_mask)
+    attention_output = self._attention_layer(
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     attention_output = self._attention_layer_norm(input_tensor +
                                                   attention_output)
official/nlp/modeling/layers/transformer_scaffold_test.py

@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
     super(ValidatedAttentionLayer, self).__init__(**kwargs)
     self.list = call_list

-  def call(self, inputs, attention_mask=None):
+  def call(self, query, value, attention_mask=None):
     self.list.append(True)
     return super(ValidatedAttentionLayer, self).call(
-        inputs, attention_mask=attention_mask)
+        query, value, attention_mask=attention_mask)

   def get_config(self):
     config = super(ValidatedAttentionLayer, self).get_config()
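A custom attention class plugged into TransformerScaffold now has to accept query and value as separate arguments, mirroring the ValidatedAttentionLayer change above. A hypothetical sketch under the same assumption about the base class:

from official.nlp.modeling.layers import attention


class LoggingAttention(attention.MultiHeadAttention):
  """Illustrative subclass that records every call, like the test's validator."""

  def __init__(self, call_list=None, **kwargs):
    super(LoggingAttention, self).__init__(**kwargs)
    self._calls = call_list if call_list is not None else []

  def call(self, query, value, attention_mask=None):
    self._calls.append(True)
    return super(LoggingAttention, self).call(
        query, value, attention_mask=attention_mask)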
official/nlp/modeling/layers/transformer_test.py

@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     _ = new_layer([input_data, mask_data])
     new_layer.set_weights(test_layer.get_weights())
     new_output_tensor = new_layer([input_data, mask_data])
-    self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
+    self.assertAllClose(
+        new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)

   def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
official/nlp/modeling/models/bert_classifier.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import tensorflow as tf

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks

@@ -36,6 +37,9 @@ class BertClassifier(tf.keras.Model):
   instantiates a classification network based on the passed `num_classes`
   argument. If `num_classes` is set to 1, a regression network is instantiated.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding

@@ -43,23 +47,25 @@ class BertClassifier(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network.
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    dropout_rate: The dropout probability of the cls head.
+    use_encoder_pooler: Whether to use the pooler layer pre-defined inside
+      the encoder.
   """

   def __init__(self,
                network,
                num_classes,
                initializer='glorot_uniform',
-               output='logits',
                dropout_rate=0.1,
+               use_encoder_pooler=True,
                **kwargs):
     self._self_setattr_tracking = False
+    self._network = network
     self._config = {
         'network': network,
         'num_classes': num_classes,
         'initializer': initializer,
-        'output': output,
+        'use_encoder_pooler': use_encoder_pooler,
     }

     # We want to use the inputs of the passed network as the inputs to this

@@ -67,22 +73,36 @@ class BertClassifier(tf.keras.Model):
     # when we construct the Model object at the end of init.
     inputs = network.inputs

-    # Because we have a copy of inputs to create this Model object, we can
-    # invoke the Network object with its own input tensors to start the Model.
-    _, cls_output = network(inputs)
-    cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
-
-    self.classifier = networks.Classification(
-        input_width=cls_output.shape[-1],
-        num_classes=num_classes,
-        initializer=initializer,
-        output=output,
-        name='classification')
-    predictions = self.classifier(cls_output)
+    if use_encoder_pooler:
+      # Because we have a copy of inputs to create this Model object, we can
+      # invoke the Network object with its own input tensors to start the Model.
+      _, cls_output = network(inputs)
+      cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
+
+      self.classifier = networks.Classification(
+          input_width=cls_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          output='logits',
+          name='sentence_prediction')
+      predictions = self.classifier(cls_output)
+    else:
+      sequence_output, _ = network(inputs)
+      self.classifier = layers.ClassificationHead(
+          inner_dim=sequence_output.shape[-1],
+          num_classes=num_classes,
+          initializer=initializer,
+          dropout_rate=dropout_rate,
+          name='sentence_prediction')
+      predictions = self.classifier(sequence_output)

     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)

+  @property
+  def checkpoint_items(self):
+    return dict(encoder=self._network)
+
   def get_config(self):
     return self._config
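Rough usage of the new use_encoder_pooler switch. Only the BertClassifier arguments come from the diff above; the encoder construction is an illustrative assumption (it presumes networks.TransformerEncoder accepts these arguments, as the tests in this commit build their encoders the same way).

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)

# use_encoder_pooler=True (default): classify from the encoder's pooled output.
pooled_model = bert_classifier.BertClassifier(encoder, num_classes=3)

# use_encoder_pooler=False: attach a ClassificationHead on the sequence output.
head_model = bert_classifier.BertClassifier(
    encoder, num_classes=3, dropout_rate=0.1, use_encoder_pooler=False)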
official/nlp/modeling/models/bert_classifier_test.py

@@ -42,8 +42,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network,
-        num_classes=num_classes)
+        test_network, num_classes=num_classes)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

@@ -89,7 +88,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=4, initializer='zeros', output='predictions')
+        test_network, num_classes=4, initializer='zeros')

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
official/nlp/modeling/models/bert_pretrainer.py

@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
   instantiates the masked language model and classification networks that are
   used to create the training objectives.

+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output.

@@ -147,11 +150,9 @@ class BertPretrainerV2(tf.keras.Model):
   (Experimental).

   Adds the masked language model head and optional classification heads upon the
-  transformer encoder. When num_masked_tokens == 0, there won't be MaskedLM
-  head.
+  transformer encoder.

   Arguments:
-    num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
     mlm_activation: The activation (if any) to use in the masked LM network. If

@@ -169,7 +170,6 @@ class BertPretrainerV2(tf.keras.Model):
   def __init__(
       self,
-      num_masked_tokens: int,
       encoder_network: tf.keras.Model,
       mlm_activation=None,
       mlm_initializer='glorot_uniform',

@@ -179,7 +179,6 @@ class BertPretrainerV2(tf.keras.Model):
     self._self_setattr_tracking = False
     self._config = {
         'encoder_network': encoder_network,
-        'num_masked_tokens': num_masked_tokens,
         'mlm_initializer': mlm_initializer,
         'classification_heads': classification_heads,
         'name': name,

@@ -195,19 +194,16 @@ class BertPretrainerV2(tf.keras.Model):
       raise ValueError('Classification heads should have unique names.')

     outputs = dict()
-    if num_masked_tokens > 0:
-      self.masked_lm = layers.MaskedLM(
-          embedding_table=self.encoder_network.get_embedding_table(),
-          activation=mlm_activation,
-          initializer=mlm_initializer,
-          name='cls/predictions')
-      masked_lm_positions = tf.keras.layers.Input(
-          shape=(num_masked_tokens,),
-          name='masked_lm_positions',
-          dtype=tf.int32)
-      inputs.append(masked_lm_positions)
-      outputs['lm_output'] = self.masked_lm(
-          sequence_output, masked_positions=masked_lm_positions)
+    self.masked_lm = layers.MaskedLM(
+        embedding_table=self.encoder_network.get_embedding_table(),
+        activation=mlm_activation,
+        initializer=mlm_initializer,
+        name='cls/predictions')
+    masked_lm_positions = tf.keras.layers.Input(
+        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
+    inputs.append(masked_lm_positions)
+    outputs['lm_output'] = self.masked_lm(
+        sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)

@@ -217,7 +213,7 @@ class BertPretrainerV2(tf.keras.Model):
   @property
   def checkpoint_items(self):
     """Returns a dictionary of items to be additionally checkpointed."""
-    items = dict(encoder=self.encoder_network)
+    items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
     for head in self.classification_heads:
       for key, item in head.checkpoint_items.items():
         items['.'.join([head.name, key])] = item
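A hedged sketch of building BertPretrainerV2 after this change: num_masked_tokens is gone and masked_lm_positions becomes a dynamically shaped input. The encoder construction, input ordering, and shapes are illustrative assumptions, not taken from the commit.

import numpy as np
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)

# Old: BertPretrainerV2(encoder_network=encoder, num_masked_tokens=20)
pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)

word_ids = np.random.randint(100, size=(2, 16))
mask = np.ones((2, 16), dtype=np.int32)
type_ids = np.zeros((2, 16), dtype=np.int32)
# Any number of masked positions can be fed now; 20 is arbitrary here.
masked_positions = np.random.randint(16, size=(2, 20))

outputs = pretrainer([word_ids, mask, type_ids, masked_positions])
print(outputs['lm_output'].shape)  # [batch, masked positions, vocab_size]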
official/nlp/modeling/models/bert_pretrainer_test.py

@@ -118,10 +118,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)

     # Create a BERT trainer with the created network.
-    num_token_predictions = 2
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=num_token_predictions)
+        encoder_network=test_network)
+    num_token_predictions = 20

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

@@ -145,7 +144,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network, num_masked_tokens=2)
+        encoder_network=test_network)

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
official/nlp/modeling/models/bert_span_labeler.py View file @ 5a2cf36f
...
@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).
-  The BertSpanLabeler allows a user to pass in a transformer stack, and
+  The BertSpanLabeler allows a user to pass in a transformer encoder, and
   instantiates a span labeling network based on a single dense layer.
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
...
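For reference, a minimal sketch of the Functional-API wiring the new note describes (the encoder construction and keyword names here are assumptions, not part of this diff):

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_span_labeler

# Assumed toy encoder; any network exposing a sequence output should work.
encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)
# The span labeler adds a single dense layer over the sequence output and
# returns start/end logits for extractive span labeling.
model = bert_span_labeler.BertSpanLabeler(network=encoder)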
official/nlp/modeling/models/bert_token_classifier.py View file @ 5a2cf36f
...
@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
   instantiates a token classification network based on the passed `num_classes`
   argument.
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
   Arguments:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
...
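The same Functional-API note applies to the token classifier; a hedged sketch with assumed constructor arguments (only `num_classes` is named by the docstring above):

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier

# Assumed toy encoder for illustration.
encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=16)
# One classification layer applied to every token's sequence output.
model = bert_token_classifier.BertTokenClassifier(
    network=encoder, num_classes=5)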
official/nlp/modeling/models/electra_pretrainer.py View file @ 5a2cf36f
...
@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
   model (at generator side) and classification networks (at discriminator side)
   that are used to create the training objectives.
+  *Note* that the model is constructed by Keras Subclass API, where layers are
+  defined inside __init__ and call() implements the computation.
+
   Arguments:
     generator_network: A transformer network for generator, this network should
       output a sequence output and an optional classification output.
...
@@ -48,7 +51,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
...
@@ -66,7 +68,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
...
@@ -80,7 +81,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
...
@@ -95,7 +95,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
...
@@ -108,10 +107,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
...
@@ -165,7 +169,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)
     outputs = {
...
@@ -214,6 +219,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }
+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
...
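The main behavioral change above is a new projection between the discriminator's sequence output and the 1-unit scoring head, with the layer widths now read from the networks' config dicts instead of a user-supplied `last_hidden_dim`. A standalone sketch of that path (the hidden size and activation here are assumed, not taken from the diff):

import tensorflow as tf

hidden_size = 768  # assumed; the diff reads it from discriminator_network._config_dict
projection = tf.keras.layers.Dense(units=hidden_size, activation='gelu')  # activation assumed
head = tf.keras.layers.Dense(units=1)

# [batch, seq_len, hidden] discriminator sequence output (random stand-in).
disc_sequence_output = tf.random.normal((2, 16, hidden_size))
# Project, score each token with the 1-unit head, then drop the trailing
# unit axis to get per-token replaced/original logits of shape [batch, seq_len].
disc_logits = tf.squeeze(head(projection(disc_sequence_output)), axis=-1)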
official/nlp/modeling/models/electra_pretrainer_test.py View file @ 5a2cf36f
...
@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
...
@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)
     # Create a set of 2-dimensional data tensors to feed into the model.
...
@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)
     # Create another BERT trainer via serialization and deserialization.
...
official/nlp/modeling/networks/albert_transformer_encoder.py View file @ 5a2cf36f
...
@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
   The default values for this object are taken from the ALBERT-Base
   implementation described in the paper.
+  *Note* that the network is constructed by Keras Functional API.
+
   Arguments:
     vocab_size: The size of the token vocabulary.
     embedding_width: The width of the word embeddings. If the embedding width is
...
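A short sketch of instantiating the ALBERT encoder with the factorized embedding the arguments above describe (all values are illustrative assumptions):

from official.nlp.modeling.networks import albert_transformer_encoder

encoder = albert_transformer_encoder.AlbertTransformerEncoder(
    vocab_size=30000,
    embedding_width=128,   # narrower than the hidden size; projected up internally
    hidden_size=768,
    num_layers=12,
    num_attention_heads=12)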