ModelZoo / ResNet50_tensorflow · Commits · b0ccdb11

Commit b0ccdb11, authored Sep 28, 2020 by Shixin Luo
Commit message: resolve conflict with master
Parents: e61588cd, 1611a8c5

Showing 20 changed files with 2168 additions and 146 deletions (+2168 −146).
official/nlp/modeling/networks/encoder_scaffold.py          +14   -4
official/nlp/modeling/networks/encoder_scaffold_test.py     +9    -11
official/nlp/modeling/networks/mobile_bert_encoder.py       +11   -20
official/nlp/modeling/networks/mobile_bert_encoder_test.py  +20   -28
official/nlp/modeling/networks/xlnet_base.py                +673  -0
official/nlp/modeling/networks/xlnet_base_test.py           +454  -0
official/nlp/modeling/ops/beam_search_test.py               +38   -2
official/nlp/nhnet/README.md                                +2    -4
official/nlp/nhnet/decoder.py                               +1    -2
official/nlp/nhnet/evaluation.py                            +2    -8
official/nlp/nhnet/trainer.py                               +5    -5
official/nlp/projects/bigbird/attention.py                  +490  -0
official/nlp/projects/bigbird/attention_test.py             +67   -0
official/nlp/projects/bigbird/encoder.py                    +196  -0
official/nlp/projects/bigbird/encoder_test.py               +63   -0
official/nlp/tasks/masked_lm.py                             +36   -33
official/nlp/tasks/utils.py                                 +34   -8
official/nlp/train.py                                       +4    -3
official/nlp/train_ctl_continuous_finetune.py               +39   -11
official/nlp/train_ctl_continuous_finetune_test.py          +10   -7
official/nlp/modeling/networks/encoder_scaffold.py

@@ -93,6 +93,7 @@ class EncoderScaffold(tf.keras.Model):
         "kernel_initializer": The initializer for the transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """

   def __init__(self,
@@ -106,6 +107,7 @@ class EncoderScaffold(tf.keras.Model):
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                return_all_layer_outputs=False,
+               dict_outputs=False,
                **kwargs):
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls
@@ -117,6 +119,7 @@ class EncoderScaffold(tf.keras.Model):
     self._embedding_cfg = embedding_cfg
     self._embedding_data = embedding_data
     self._return_all_layer_outputs = return_all_layer_outputs
+    self._dict_outputs = dict_outputs
     self._kwargs = kwargs

     if embedding_cls:
@@ -138,7 +141,7 @@ class EncoderScaffold(tf.keras.Model):
           shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
       inputs = [word_ids, mask, type_ids]

-      self._embedding_layer = layers.OnDeviceEmbedding(
+      self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
           vocab_size=embedding_cfg['vocab_size'],
           embedding_width=embedding_cfg['hidden_size'],
           initializer=embedding_cfg['initializer'],
@@ -147,13 +150,13 @@ class EncoderScaffold(tf.keras.Model):
       word_embeddings = self._embedding_layer(word_ids)

       # Always uses dynamic slicing for simplicity.
-      self._position_embedding_layer = keras_nlp.PositionEmbedding(
+      self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
           initializer=embedding_cfg['initializer'],
           max_length=embedding_cfg['max_seq_length'],
           name='position_embedding')
       position_embeddings = self._position_embedding_layer(word_embeddings)

-      self._type_embedding_layer = layers.OnDeviceEmbedding(
+      self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
           vocab_size=embedding_cfg['type_vocab_size'],
           embedding_width=embedding_cfg['hidden_size'],
           initializer=embedding_cfg['initializer'],
@@ -200,7 +203,13 @@ class EncoderScaffold(tf.keras.Model):
         name='cls_transform')
     cls_output = self._pooler_layer(first_token_tensor)

-    if return_all_layer_outputs:
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=layer_output_data[-1],
+          pooled_output=cls_output,
+          encoder_outputs=layer_output_data,
+      )
+    elif return_all_layer_outputs:
       outputs = [layer_output_data, cls_output]
     else:
       outputs = [layer_output_data[-1], cls_output]
@@ -219,6 +228,7 @@ class EncoderScaffold(tf.keras.Model):
         'embedding_cfg': self._embedding_cfg,
         'hidden_cfg': self._hidden_cfg,
         'return_all_layer_outputs': self._return_all_layer_outputs,
+        'dict_outputs': self._dict_outputs,
     }
     if inspect.isclass(self._hidden_cls):
       config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
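The hunks above thread a new `dict_outputs` flag through `EncoderScaffold`. Below is a minimal usage sketch of the new path. The constructor arguments and config key names follow the class docstring and the test changes later in this commit, but the concrete values (sequence length, head counts, sizes) are illustrative assumptions, not part of the diff.

# Sketch only: exercises the new dict_outputs=True branch of EncoderScaffold.
# Config keys mirror the class docstring/tests in this commit; values are
# illustrative assumptions.
import tensorflow as tf
from official.nlp.modeling.networks import encoder_scaffold

seq_length, vocab_size, hidden_size = 128, 100, 32
init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
embedding_cfg = {
    'vocab_size': vocab_size,
    'type_vocab_size': 2,
    'hidden_size': hidden_size,
    'seq_length': seq_length,
    'max_seq_length': seq_length,
    'initializer': init,
    'dropout_rate': 0.1,
}
hidden_cfg = {
    'num_attention_heads': 2,
    'intermediate_size': 64,
    'intermediate_activation': 'gelu',
    'dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'kernel_initializer': init,
}
encoder = encoder_scaffold.EncoderScaffold(
    num_hidden_instances=2,
    pooled_output_dim=hidden_size,
    pooler_layer_initializer=init,
    hidden_cfg=hidden_cfg,
    embedding_cfg=embedding_cfg,
    dict_outputs=True)  # new flag added by this commit

word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
outputs = encoder([word_ids, mask, type_ids])
# With dict_outputs=True the network returns a dictionary instead of a list.
sequence_output = outputs['sequence_output']  # [batch, seq_length, hidden]
pooled_output = outputs['pooled_output']      # [batch, hidden]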
official/nlp/modeling/networks/encoder_scaffold_test.py

@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for transformer-based text encoder network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for EncoderScaffold network."""

 from absl.testing import parameterized
 import numpy as np
@@ -218,16 +214,17 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cfg=hidden_cfg,
-        embedding_cfg=embedding_cfg)
+        embedding_cfg=embedding_cfg,
+        dict_outputs=True)

     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    data, pooled = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])

     # Create a model based off of this network:
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     # Invoke the model. We can't validate the output data here (the model is too
     # complex) but this will catch structural runtime errors.
@@ -237,7 +234,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     type_id_data = np.random.randint(
         num_types, size=(batch_size, sequence_length))
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    preds = model.predict([word_id_data, mask_data, type_id_data])
+    self.assertEqual(preds["pooled_output"].shape, (3, hidden_size))

     # Creates a EncoderScaffold with max_sequence_length != sequence_length
     num_types = 7
@@ -272,8 +270,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             stddev=0.02),
         hidden_cfg=hidden_cfg,
         embedding_cfg=embedding_cfg)
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)
     _ = model.predict([word_id_data, mask_data, type_id_data])

   def test_serialize_deserialize(self):
official/nlp/modeling/networks/mobile_bert_encoder.py

@@ -101,18 +101,18 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
     self.max_sequence_length = max_sequence_length
     self.dropout_rate = dropout_rate

-    self.word_embedding = layers.OnDeviceEmbedding(
+    self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
         self.word_vocab_size,
         self.word_embed_size,
         initializer=initializer,
         name='word_embedding')
-    self.type_embedding = layers.OnDeviceEmbedding(
+    self.type_embedding = keras_nlp.layers.OnDeviceEmbedding(
         self.type_vocab_size,
         self.output_embed_size,
         use_one_hot=True,
         initializer=initializer,
         name='type_embedding')
-    self.pos_embedding = keras_nlp.PositionEmbedding(
+    self.pos_embedding = keras_nlp.layers.PositionEmbedding(
         max_length=max_sequence_length,
         initializer=initializer,
         name='position_embedding')
@@ -127,7 +127,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
         self.dropout_rate,
         name='embedding_dropout')

-  def call(self, input_ids, token_type_ids=None, training=False):
+  def call(self, input_ids, token_type_ids=None):
     word_embedding_out = self.word_embedding(input_ids)
     word_embedding_out = tf.concat(
         [tf.pad(word_embedding_out[:, 1:], ((0, 0), (0, 1), (0, 0))),
@@ -142,7 +142,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
       type_embedding_out = self.type_embedding(token_type_ids)
       embedding_out += type_embedding_out
     embedding_out = self.layer_norm(embedding_out)
-    embedding_out = self.dropout_layer(embedding_out, training=training)
+    embedding_out = self.dropout_layer(embedding_out)
     return embedding_out
@@ -300,7 +300,6 @@ class TransformerLayer(tf.keras.layers.Layer):
   def call(self,
            input_tensor,
            attention_mask=None,
-           training=False,
            return_attention_scores=False):
     """Implementes the forward pass.
@@ -309,7 +308,6 @@ class TransformerLayer(tf.keras.layers.Layer):
       attention_mask: (optional) int32 tensor of shape [batch_size, seq_length,
         seq_length], with 1 for positions that can be attended to and 0 in
         positions that should not be.
-      training: If the model is in training mode.
       return_attention_scores: If return attention score.

     Returns:
@@ -326,7 +324,6 @@ class TransformerLayer(tf.keras.layers.Layer):
                        f'hidden size {self.hidden_size}'))
     prev_output = input_tensor
     # input bottleneck
     dense_layer = self.block_layers['bottleneck_input'][0]
     layer_norm = self.block_layers['bottleneck_input'][1]
@@ -355,7 +352,6 @@ class TransformerLayer(tf.keras.layers.Layer):
           key_tensor,
           attention_mask,
           return_attention_scores=True,
-          training=training)
       attention_output = layer_norm(attention_output + layer_input)
@@ -375,7 +371,7 @@ class TransformerLayer(tf.keras.layers.Layer):
       dropout_layer = self.block_layers['bottleneck_output'][1]
       layer_norm = self.block_layers['bottleneck_output'][2]
       layer_output = bottleneck(layer_output)
-      layer_output = dropout_layer(layer_output, training=training)
+      layer_output = dropout_layer(layer_output)
       layer_output = layer_norm(layer_output + prev_output)
     if return_attention_scores:
@@ -406,8 +402,6 @@ class MobileBERTEncoder(tf.keras.Model):
                num_feedforward_networks=4,
                normalization_type='no_norm',
                classifier_activation=False,
-               return_all_layers=False,
-               return_attention_score=False,
                **kwargs):
     """Class initialization.
@@ -438,8 +432,6 @@ class MobileBERTEncoder(tf.keras.Model):
         MobileBERT paper. 'layer_norm' is used for the teacher model.
       classifier_activation: If using the tanh activation for the final
         representation of the [CLS] token in fine-tuning.
-      return_all_layers: If return all layer outputs.
-      return_attention_score: If return attention scores for each layer.
       **kwargs: Other keyworded and arguments.
     """
     self._self_setattr_tracking = False
@@ -513,12 +505,11 @@ class MobileBERTEncoder(tf.keras.Model):
     else:
       self._pooler_layer = None

-    if return_all_layers:
-      outputs = [all_layer_outputs, first_token]
-    else:
-      outputs = [prev_output, first_token]
-
-    if return_attention_score:
-      outputs.append(all_attention_scores)
+    outputs = dict(
+        sequence_output=prev_output,
+        pooled_output=first_token,
+        encoder_outputs=all_layer_outputs,
+        attention_scores=all_attention_scores)

     super(MobileBERTEncoder, self).__init__(
         inputs=self.inputs, outputs=outputs, **kwargs)
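After this change `MobileBERTEncoder` no longer takes `return_all_layers`/`return_attention_score`; it always returns a dictionary with all four outputs. A minimal hedged sketch of the new calling convention follows; the constructor arguments come from the test changes below, and the sizes are illustrative assumptions.

# Sketch only: MobileBERTEncoder now always returns a dict of outputs; the
# former return_all_layers/return_attention_score flags are gone. Sizes here
# are illustrative, not values from the diff.
import tensorflow as tf
from official.nlp.modeling.networks import mobile_bert_encoder

seq_length = 16
encoder = mobile_bert_encoder.MobileBERTEncoder(
    word_vocab_size=100, hidden_size=32, num_blocks=2)

word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
outputs = encoder([word_ids, mask, type_ids])

sequence_output = outputs['sequence_output']    # final block output
pooled_output = outputs['pooled_output']        # [CLS] representation
all_blocks = outputs['encoder_outputs']         # one entry per block (+ embedding)
attention_scores = outputs['attention_scores']  # one entry per block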
official/nlp/modeling/networks/mobile_bert_encoder_test.py

@@ -32,7 +32,7 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
   return fake_input

-class ModelingTest(parameterized.TestCase, tf.test.TestCase):
+class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):

   def test_embedding_layer_with_token_type(self):
     layer = mobile_bert_encoder.MobileBertEmbedding(10, 8, 2, 16)
@@ -116,7 +116,9 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_output, pooler_output = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])
+    layer_output, pooler_output = outputs['sequence_output'], outputs[
+        'pooled_output']

     self.assertIsInstance(test_network.transformer_layers, list)
     self.assertLen(test_network.transformer_layers, num_blocks)
@@ -134,13 +136,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=100,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=True)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    all_layer_output, _ = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])
+    all_layer_output = outputs['encoder_outputs']

     self.assertIsInstance(all_layer_output, list)
     self.assertLen(all_layer_output, num_blocks + 1)
@@ -153,16 +155,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=vocab_size,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=False)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_out_tensor, pooler_out_tensor = test_network(
-        [word_ids, mask, type_ids])
-    model = tf.keras.Model([word_ids, mask, type_ids],
-                           [layer_out_tensor, pooler_out_tensor])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     input_seq = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
@@ -170,13 +169,12 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
         batch_size=1, seq_len=sequence_length, vocab_size=2)
     token_type = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=2)
-    layer_output, pooler_output = model.predict(
-        [input_seq, input_mask, token_type])
+    outputs = model.predict([input_seq, input_mask, token_type])

-    layer_output_shape = [1, sequence_length, hidden_size]
-    self.assertAllEqual(layer_output.shape, layer_output_shape)
-    pooler_output_shape = [1, hidden_size]
-    self.assertAllEqual(pooler_output.shape, pooler_output_shape)
+    sequence_output_shape = [1, sequence_length, hidden_size]
+    self.assertAllEqual(outputs['sequence_output'].shape,
+                        sequence_output_shape)
+    pooled_output_shape = [1, hidden_size]
+    self.assertAllEqual(outputs['pooled_output'].shape, pooled_output_shape)

   def test_mobilebert_encoder_invocation_with_attention_score(self):
     vocab_size = 100
@@ -186,18 +184,13 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=vocab_size,
         hidden_size=hidden_size,
-        num_blocks=num_blocks,
-        return_all_layers=False,
-        return_attention_score=True)
+        num_blocks=num_blocks)

     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    layer_out_tensor, pooler_out_tensor, attention_out_tensor = test_network(
-        [word_ids, mask, type_ids])
-    model = tf.keras.Model(
-        [word_ids, mask, type_ids],
-        [layer_out_tensor, pooler_out_tensor, attention_out_tensor])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     input_seq = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=vocab_size)
@@ -205,9 +198,8 @@ class ModelingTest(parameterized.TestCase, tf.test.TestCase):
         batch_size=1, seq_len=sequence_length, vocab_size=2)
     token_type = generate_fake_input(
         batch_size=1, seq_len=sequence_length, vocab_size=2)
-    _, _, attention_score_output = model.predict(
-        [input_seq, input_mask, token_type])
-    self.assertLen(attention_score_output, num_blocks)
+    outputs = model.predict([input_seq, input_mask, token_type])
+    self.assertLen(outputs['attention_scores'], num_blocks)

   @parameterized.named_parameters(
       ('sequence_classification', models.BertClassifier, [None, 5]),
official/nlp/modeling/networks/xlnet_base.py  (new file, 0 → 100644, +673 lines)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based XLNet Model."""

from absl import logging
import tensorflow as tf

from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl

_SEG_ID_CLS = 2


def _create_causal_attention_mask(seq_length,
                                  memory_length,
                                  dtype=tf.float32,
                                  same_length=False):
  """Creates a causal attention mask with a single-sided context.

  When applying the attention mask in `MultiHeadRelativeAttention`, the
  attention scores are of shape `[(batch dimensions), S, S + M]`, where:
  - S = sequence length.
  - M = memory length.

  In a simple case where S = 2, M = 1, here is a simple illustration of the
  `attention_scores` matrix, where `a` represents an attention function:

  token_0 [[a(token_0, mem_0) a(token_0, token_0) a(token_0, token_1)],
  token_1  [a(token_1, mem_0) a(token_1, token_0) a(token_1, token_1)]]
                  mem_0             token_0              token_1

  For uni-directional attention, we want to mask out values in the attention
  scores that represent a(token_i, token_j) where j > i. We can achieve this by
  concatenating 0s (representing memory positions) with a strictly upper
  triangular matrix of 1s.

  Arguments:
    seq_length: int, The length of each sequence.
    memory_length: int, The length of memory blocks.
    dtype: dtype of the mask.
    same_length: bool, whether to use the same attention length for each token.

  Returns:
    A unidirectional attention mask of shape
    `[seq_length, seq_length + memory_length]`. E.g.:

    [[0. 0. 0. 1. 1. 1.]
     [0. 0. 0. 0. 1. 1.]
     [0. 0. 0. 0. 0. 1.]
     [0. 0. 0. 0. 0. 0.]]
  """
  ones_matrix = tf.ones([seq_length, seq_length], dtype=dtype)
  upper_triangular = tf.linalg.band_part(ones_matrix, 0, -1)
  diagonal = tf.linalg.band_part(ones_matrix, 0, 0)

  padding = tf.zeros([seq_length, memory_length], dtype=dtype)
  causal_attention_mask = tf.concat(
      [padding, upper_triangular - diagonal], 1)
  if same_length:
    lower_triangular = tf.linalg.band_part(ones_matrix, -1, 0)
    strictly_lower_triangular = lower_triangular - diagonal
    causal_attention_mask = tf.concat(
        [causal_attention_mask[:, :seq_length] + strictly_lower_triangular,
         causal_attention_mask[:, seq_length:]], 1)
  return causal_attention_mask


def _compute_attention_mask(input_mask,
                            permutation_mask,
                            attention_type,
                            seq_length,
                            memory_length,
                            batch_size,
                            dtype=tf.float32):
  """Combines all input attention masks for XLNet.

  In XLNet modeling, `0` represents tokens that can be attended, and `1`
  represents tokens that cannot be attended.

  For XLNet pre-training and fine tuning, there are a few masks used:
  - Causal attention mask: If the attention type is unidirectional, then all
    tokens after the current position cannot be attended to.
  - Input mask: when generating data, padding is added to a max sequence length
    to make all sequences the same length. This masks out real tokens (`0`)
    from padding tokens (`1`).
  - Permutation mask: during XLNet pretraining, the input sequence is
    factorized into a factorization sequence `z`. During partial prediction,
    `z` is split at a cutting point `c` (an index of the factorization
    sequence) and prediction is only applied to all tokens after `c`.
    Therefore, tokens at factorization positions `i` > `c` can be attended to
    and tokens at factorization positions `i` <= `c` cannot be attended to.

  This function broadcasts and combines all attention masks to produce the
  query attention mask and the content attention mask.

  Args:
    input_mask: Tensor, the input mask related to padding. Input shape:
      `(B, S)`.
    permutation_mask: Tensor, the permutation mask used in partial prediction.
      Input shape: `(B, S, S)`.
    attention_type: str, the attention type. Can be "uni" (directional) or
      "bi" (directional).
    seq_length: int, the length of each sequence.
    memory_length: int the length of memory blocks.
    batch_size: int, the batch size.
    dtype: The dtype of the masks.

  Returns:
    attention_mask, content_attention_mask: The position and context-based
      attention masks and content attention masks, respectively.
  """
  attention_mask = None
  # `1` values mean do not attend to this position.
  if attention_type == "uni":
    causal_attention_mask = _create_causal_attention_mask(
        seq_length=seq_length,
        memory_length=memory_length,
        dtype=dtype)
    causal_attention_mask = causal_attention_mask[None, None, :, :]

  # `causal_attention_mask`: [1, 1, S, S + M]
  # input_mask: [B, S]
  # permutation_mask: [B, S, S]
  if input_mask is not None and permutation_mask is not None:
    data_mask = input_mask[:, None, :] + permutation_mask
  elif input_mask is not None and permutation_mask is None:
    data_mask = input_mask[:, None, :]
  elif input_mask is None and permutation_mask is not None:
    data_mask = permutation_mask
  else:
    data_mask = None

  # data_mask: [B, S, S] or [B, 1, S]
  if data_mask is not None:
    # All positions within state can be attended to.
    state_mask = tf.zeros([batch_size, tf.shape(data_mask)[1], memory_length],
                          dtype=dtype)
    # state_mask: [B, 1, M] or [B, S, M]
    data_mask = tf.concat([state_mask, data_mask], 2)
    # data_mask: [B, 1, S + M] or [B, S, S + M]
    if attention_type == "uni":
      attention_mask = causal_attention_mask + data_mask[:, None, :, :]
    else:
      attention_mask = data_mask[:, None, :, :]

  # Construct the content attention mask.
  if attention_mask is not None:
    attention_mask = tf.cast(attention_mask > 0, dtype=dtype)

    non_tgt_mask = -tf.eye(seq_length, dtype=dtype)
    non_tgt_mask = tf.concat(
        [tf.zeros([seq_length, memory_length], dtype=dtype), non_tgt_mask],
        axis=-1)
    content_attention_mask = tf.cast(
        (attention_mask + non_tgt_mask[None, None, :, :]) > 0, dtype=dtype)
  else:
    content_attention_mask = None

  return attention_mask, content_attention_mask


def _compute_segment_matrix(segment_ids,
                            memory_length,
                            batch_size,
                            use_cls_mask):
  """Computes the segment embedding matrix.

  XLNet introduced segment-based attention for attention calculations. This
  extends the idea of relative encodings in Transformer XL by considering
  whether or not two positions are within the same segment, rather than
  which segments they come from.

  This function generates a segment matrix by broadcasting provided segment IDs
  in two different dimensions and checking where values are equal. This output
  matrix shows `True` whenever two tokens are NOT in the same segment and
  `False` whenever they are.

  Args:
    segment_ids: A Tensor of size `[B, S]` that represents which segment
      each token belongs to.
    memory_length: int, the length of memory blocks.
    batch_size: int, the batch size.
    use_cls_mask: bool, whether or not to introduce cls mask in
      input sequences.

  Returns:
    A boolean Tensor of size `[B, S, S + M]`, where `True` means that two
    tokens are NOT in the same segment, and `False` means they are in the same
    segment.
  """
  if segment_ids is None:
    return None

  memory_padding = tf.zeros([batch_size, memory_length], dtype=tf.int32)
  padded_segment_ids = tf.concat([memory_padding, segment_ids], 1)
  # segment_ids: [B, S]
  # padded_segment_ids: [B, S + M]

  if use_cls_mask:
    # `1` indicates not in the same segment.
    # Target result: [B, S, S + M]

    # segment_ids: [B, S]
    # padded_segment_ids: [B, S + M]
    broadcasted_segment_class_indices = (
        tf.equal(segment_ids, tf.constant([_SEG_ID_CLS]))[:, :, None])
    broadcasted_padded_class_indices = (
        tf.equal(padded_segment_ids, tf.constant([_SEG_ID_CLS]))[:, None, :])

    class_index_matrix = tf.logical_or(broadcasted_segment_class_indices,
                                       broadcasted_padded_class_indices)

    segment_matrix = tf.equal(segment_ids[:, :, None],
                              padded_segment_ids[:, None, :])
    segment_matrix = tf.logical_or(class_index_matrix, segment_matrix)
  else:
    # TODO(allencwang) - address this legacy mismatch from `use_cls_mask`.
    segment_matrix = tf.logical_not(
        tf.equal(segment_ids[:, :, None], padded_segment_ids[:, None, :]))
  return segment_matrix


def _compute_positional_encoding(attention_type,
                                 position_encoding_layer,
                                 hidden_size,
                                 batch_size,
                                 total_length,
                                 seq_length,
                                 clamp_length,
                                 bi_data,
                                 dtype=tf.float32):
  """Computes the relative position encoding.

  Args:
    attention_type: str, the attention type. Can be "uni" (directional) or
      "bi" (directional).
    position_encoding_layer: An instance of `RelativePositionEncoding`.
    hidden_size: int, the hidden size.
    batch_size: int, the batch size.
    total_length: int, the sequence length added to the memory length.
    seq_length: int, the length of each sequence.
    clamp_length: int, clamp all relative distances larger than clamp_length.
      -1 means no clamping.
    bi_data: bool, whether to use bidirectional input pipeline. Usually set to
      True during pretraining and False during finetuning.
    dtype: the dtype of the encoding.

  Returns:
    A Tensor, representing the position encoding.
  """
  freq_seq = tf.range(0, hidden_size, 2.0)
  if dtype is not None and dtype != tf.float32:
    freq_seq = tf.cast(freq_seq, dtype=dtype)

  if attention_type == "bi":
    beg, end = total_length, -seq_length
  elif attention_type == "uni":
    beg, end = total_length, -1
  else:
    raise ValueError("Unknown `attention_type` {}.".format(attention_type))

  if bi_data:
    forward_position_sequence = tf.range(beg, end, -1.0)
    backward_position_sequence = tf.range(-beg, -end, 1.0)

    if dtype is not None and dtype != tf.float32:
      forward_position_sequence = tf.cast(forward_position_sequence,
                                          dtype=dtype)
      backward_position_sequence = tf.cast(backward_position_sequence,
                                           dtype=dtype)

    if clamp_length > 0:
      forward_position_sequence = tf.clip_by_value(forward_position_sequence,
                                                   -clamp_length, clamp_length)
      backward_position_sequence = tf.clip_by_value(backward_position_sequence,
                                                    -clamp_length,
                                                    clamp_length)

    if batch_size is not None:
      forward_positional_encoding = position_encoding_layer(
          forward_position_sequence, batch_size // 2)
      backward_positional_encoding = position_encoding_layer(
          backward_position_sequence, batch_size // 2)
    else:
      forward_positional_encoding = position_encooding = position_encoding_layer(
          forward_position_sequence, None)
      backward_positional_encoding = position_encoding_layer(
          backward_position_sequence, None)

    relative_position_encoding = tf.concat(
        [forward_positional_encoding, backward_positional_encoding], axis=0)
  else:
    forward_position_sequence = tf.range(beg, end, -1.0)
    if dtype is not None and dtype != tf.float32:
      forward_position_sequence = tf.cast(
          forward_position_sequence, dtype=dtype)
    if clamp_length > 0:
      forward_position_sequence = tf.clip_by_value(
          forward_position_sequence, -clamp_length, clamp_length)

    relative_position_encoding = position_encoding_layer(
        forward_position_sequence, batch_size)
  return relative_position_encoding


class RelativePositionEncoding(tf.keras.layers.Layer):
  """Creates a relative positional encoding.

  This layer creates a relative positional encoding as described in
  "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
  (https://arxiv.org/abs/1901.02860).

  Rather than an absolute position embedding as in Transformer, this
  formulation represents position as the relative distance between tokens
  using sinusoidal positional embeddings.

  Note: This layer is currently experimental.

  Attributes:
    hidden_size: The dimensionality of the input embeddings.
  """

  def __init__(self, hidden_size, **kwargs):
    super(RelativePositionEncoding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._inv_freq = 1.0 / (10000.0**(
        tf.range(0, self._hidden_size, 2.0) / self._hidden_size))

  def call(self, pos_seq, batch_size=None):
    """Implements call() for the layer.

    Arguments:
      pos_seq: A 1-D `Tensor`
      batch_size: The optionally provided batch size that tiles the relative
        positional encoding.

    Returns:
      The relative positional encoding of shape:
        [batch_size, len(pos_seq), hidden_size] if batch_size is provided,
        else [1, len(pos_seq), hidden_size].
    """
    sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
    relative_position_encoding = tf.concat([tf.sin(sinusoid_input),
                                            tf.cos(sinusoid_input)], -1)
    relative_position_encoding = relative_position_encoding[None, :, :]
    if batch_size is not None:
      relative_position_encoding = tf.tile(relative_position_encoding,
                                           [batch_size, 1, 1])
    return relative_position_encoding


@tf.keras.utils.register_keras_serializable(package="Text")
class XLNetBase(tf.keras.layers.Layer):
  """Base XLNet model.

  Attributes:
    vocab_size: int, the number of tokens in vocabulary.
    num_layers: int, the number of layers.
    hidden_size: int, the hidden size.
    num_attention_heads: int, the number of attention heads.
    head_size: int, the dimension size of each attention head.
    inner_size: int, the hidden size in feed-forward layers.
    dropout_rate: float, dropout rate.
    attention_dropout_rate: float, dropout rate on attention probabilities.
    attention_type: str, "uni" or "bi".
    bi_data: bool, whether to use bidirectional input pipeline. Usually set to
      True during pretraining and False during finetuning.
    initializer: A tf initializer.
    two_stream: bool, whether or not to use `TwoStreamRelativeAttention` used
      in the XLNet pretrainer. If `False`, then it will use
      `MultiHeadRelativeAttention` as in Transformer XL.
    tie_attention_biases: bool, whether or not to tie the biases together.
      Usually set to `True`. Used for backwards compatibility.
    memory_length: int, the number of tokens to cache.
    same_length: bool, whether to use the same attention length for each
      token.
    clamp_length: int, clamp all relative distances larger than clamp_length.
      -1 means no clamping.
    reuse_length: int, the number of tokens in the currect batch to be cached
      and reused in the future.
    inner_activation: str, "relu" or "gelu".
    use_cls_mask: bool, whether or not cls mask is included in the
      input sequences.
    embedding_width: The width of the word embeddings. If the embedding width
      is not equal to hidden size, embedding parameters will be factorized
      into two matrices in the shape of ["vocab_size", "embedding_width"] and
      ["embedding_width", "hidden_size"] ("embedding_width" is usually much
      smaller than "hidden_size").
    embedding_layer: The word embedding layer. `None` means we will create a
      new embedding layer. Otherwise, we will reuse the given embedding layer.
      This parameter is originally added for ELECTRA model which needs to tie
      the generator embeddings with the discriminator embeddings.
  """

  def __init__(self,
               vocab_size,
               num_layers,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               attention_type,
               bi_data,
               initializer,
               two_stream=False,
               tie_attention_biases=True,
               memory_length=None,
               clamp_length=-1,
               reuse_length=None,
               inner_activation="relu",
               use_cls_mask=False,
               embedding_width=None,
               **kwargs):
    super(XLNetBase, self).__init__(**kwargs)

    self._vocab_size = vocab_size
    self._initializer = initializer
    self._attention_type = attention_type
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._num_attention_heads = num_attention_heads
    self._head_size = head_size
    self._inner_size = inner_size
    self._inner_activation = inner_activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._tie_attention_biases = tie_attention_biases
    self._two_stream = two_stream

    self._memory_length = memory_length
    self._reuse_length = reuse_length
    self._bi_data = bi_data
    self._clamp_length = clamp_length
    self._use_cls_mask = use_cls_mask

    self._segment_embedding = None
    self._mask_embedding = None
    self._embedding_width = embedding_width

    if embedding_width is None:
      embedding_width = hidden_size

    self._embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
        embedding_width=embedding_width,
        initializer=self._initializer,
        dtype=tf.float32,
        name="word_embedding")
    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

    self.embedding_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self.position_encoding = RelativePositionEncoding(self._hidden_size)

    self._transformer_xl = transformer_xl.TransformerXL(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        head_size=head_size,
        inner_size=inner_size,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        initializer=initializer,
        two_stream=two_stream,
        tie_attention_biases=tie_attention_biases,
        memory_length=memory_length,
        reuse_length=reuse_length,
        inner_activation=inner_activation,
        name="transformer_xl")

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "num_layers": self._num_layers,
        "hidden_size": self._hidden_size,
        "num_attention_heads": self._num_attention_heads,
        "head_size": self._head_size,
        "inner_size": self._inner_size,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "attention_type": self._attention_type,
        "bi_data": self._bi_data,
        "initializer": self._initializer,
        "two_stream": self._two_stream,
        "tie_attention_biases": self._tie_attention_biases,
        "memory_length": self._memory_length,
        "clamp_length": self._clamp_length,
        "reuse_length": self._reuse_length,
        "inner_activation": self._inner_activation,
        "use_cls_mask": self._use_cls_mask,
        "embedding_width": self._embedding_width,
    }
    base_config = super(XLNetBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def get_embedding_lookup_table(self):
    """Returns the embedding layer weights."""
    return self._embedding_layer.embeddings

  def __call__(self,
               input_ids,
               segment_ids=None,
               input_mask=None,
               state=None,
               permutation_mask=None,
               target_mapping=None,
               masked_tokens=None,
               **kwargs):
    # Uses dict to feed inputs into call() in order to keep state as a python
    # list.
    inputs = {
        "input_ids": input_ids,
        "segment_ids": segment_ids,
        "input_mask": input_mask,
        "state": state,
        "permutation_mask": permutation_mask,
        "target_mapping": target_mapping,
        "masked_tokens": masked_tokens
    }
    return super(XLNetBase, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    """Implements call() for the layer."""
    input_ids = inputs["input_ids"]
    segment_ids = inputs["segment_ids"]
    input_mask = inputs["input_mask"]
    state = inputs["state"]
    permutation_mask = inputs["permutation_mask"]
    target_mapping = inputs["target_mapping"]
    masked_tokens = inputs["masked_tokens"]

    batch_size = tf.shape(input_ids)[0]
    seq_length = input_ids.shape.as_list()[1]
    memory_length = state[0].shape.as_list()[1] if state is not None else 0
    total_length = memory_length + seq_length

    if self._two_stream and masked_tokens is None:
      raise ValueError("`masked_tokens` must be provided in order to "
                       "initialize the query stream in "
                       "`TwoStreamRelativeAttention`.")
    if masked_tokens is not None and not self._two_stream:
      logging.warning("`masked_tokens` is provided but `two_stream` is not "
                      "enabled. Please enable `two_stream` to enable two "
                      "stream attention.")

    query_attention_mask, content_attention_mask = _compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type=self._attention_type,
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)
    relative_position_encoding = _compute_positional_encoding(
        attention_type=self._attention_type,
        position_encoding_layer=self.position_encoding,
        hidden_size=self._hidden_size,
        batch_size=batch_size,
        total_length=total_length,
        seq_length=seq_length,
        clamp_length=self._clamp_length,
        bi_data=self._bi_data,
        dtype=tf.float32)
    relative_position_encoding = self.embedding_dropout(
        relative_position_encoding)

    if segment_ids is None:
      segment_embedding = None
      segment_matrix = None
    else:
      if self._segment_embedding is None:
        self._segment_embedding = self.add_weight(
            "seg_embed",
            shape=[self._num_layers, 2, self._num_attention_heads,
                   self._head_size],
            dtype=tf.float32,
            initializer=self._initializer)

      segment_embedding = self._segment_embedding
      segment_matrix = _compute_segment_matrix(
          segment_ids=segment_ids,
          memory_length=memory_length,
          batch_size=batch_size,
          use_cls_mask=self._use_cls_mask)

    word_embeddings = self._embedding_layer(input_ids)
    content_stream = self._dropout(word_embeddings)

    if self._two_stream:
      if self._mask_embedding is None:
        self._mask_embedding = self.add_weight(
            "mask_emb/mask_emb",
            shape=[1, 1, self._hidden_size],
            dtype=tf.float32)
      if target_mapping is None:
        masked_tokens = masked_tokens[:, :, None]
        masked_token_embedding = (
            masked_tokens * self._mask_embedding +
            (1 - masked_tokens) * word_embeddings)
      else:
        masked_token_embedding = tf.tile(
            self._mask_embedding,
            [batch_size, tf.shape(target_mapping)[1], 1])
      query_stream = self._dropout(masked_token_embedding)
    else:
      query_stream = None

    return self._transformer_xl(
        content_stream=content_stream,
        query_stream=query_stream,
        target_mapping=target_mapping,
        state=state,
        relative_position_encoding=relative_position_encoding,
        segment_matrix=segment_matrix,
        segment_embedding=segment_embedding,
        content_attention_mask=content_attention_mask,
        query_attention_mask=query_attention_mask)
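The `_create_causal_attention_mask` docstring above describes the construction: a strictly upper-triangular matrix of 1s (future tokens) with zero columns prepended for the always-visible memory positions. The standalone sketch below reproduces that construction for S = 3, M = 2 and matches the expected values in xlnet_base_test.py; it is an illustration only, not part of the committed module.

# Sketch: the band_part construction used by _create_causal_attention_mask,
# shown standalone for S=3, M=2. A value of 1 marks a position that must NOT
# be attended to (a token strictly in the future); the first M columns are
# memory and stay visible.
import tensorflow as tf

S, M = 3, 2
ones = tf.ones([S, S])
upper = tf.linalg.band_part(ones, 0, -1)   # upper triangle incl. diagonal
diag = tf.linalg.band_part(ones, 0, 0)     # diagonal only
strictly_upper = upper - diag              # 1s strictly above the diagonal
mask = tf.concat([tf.zeros([S, M]), strictly_upper], axis=1)
print(mask.numpy())
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]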
official/nlp/modeling/networks/xlnet_base_test.py  (new file, 0 → 100644, +454 lines)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import xlnet_base


@keras_parameterized.run_all_keras_modes
class RelativePositionEncodingTest(keras_parameterized.TestCase):

  def test_positional_embedding(self):
    """A low-dimensional example is tested.

    With len(pos_seq)=2 and d_model=4:
      pos_seq  = [[1.], [0.]]
      inv_freq = [1., 0.01]
      pos_seq x inv_freq = [[1, 0.01], [0., 0.]]
      pos_emb = [[sin(1.), sin(0.01), cos(1.), cos(0.01)],
                 [sin(0.), sin(0.), cos(0.), cos(0.)]]
              = [[0.84147096, 0.00999983, 0.54030228, 0.99994999],
                 [0., 0., 1., 1.]]
    """
    target = np.array([[[0.84147096, 0.00999983, 0.54030228, 0.99994999],
                        [0., 0., 1., 1.]]])
    hidden_size = 4
    pos_seq = tf.range(1, -1, -1.0)  # [1., 0.]
    encoding_layer = xlnet_base.RelativePositionEncoding(
        hidden_size=hidden_size)
    encoding = encoding_layer(pos_seq, batch_size=None).numpy().astype(float)
    self.assertAllClose(encoding, target)


class ComputePositionEncodingTest(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(
      attention_type=["uni", "bi"],
      bi_data=[False, True],
  ))
  def test_compute_position_encoding_smoke(self, attention_type, bi_data):
    hidden_size = 4
    batch_size = 4
    total_length = 8
    seq_length = 4
    position_encoding_layer = xlnet_base.RelativePositionEncoding(
        hidden_size=hidden_size)
    encoding = xlnet_base._compute_positional_encoding(
        attention_type=attention_type,
        position_encoding_layer=position_encoding_layer,
        hidden_size=hidden_size,
        batch_size=batch_size,
        total_length=total_length,
        seq_length=seq_length,
        clamp_length=2,
        bi_data=bi_data,
        dtype=tf.float32)
    self.assertEqual(encoding.shape[0], batch_size)
    self.assertEqual(encoding.shape[2], hidden_size)


class CausalAttentionMaskTests(tf.test.TestCase):

  def test_casual_attention_mask_with_no_memory(self):
    seq_length, memory_length = 3, 0
    causal_attention_mask = xlnet_base._create_causal_attention_mask(
        seq_length=seq_length,
        memory_length=memory_length)
    expected_output = np.array([[0, 1, 1],
                                [0, 0, 1],
                                [0, 0, 0]])
    self.assertAllClose(causal_attention_mask, expected_output)

  def test_casual_attention_mask_with_memory(self):
    seq_length, memory_length = 3, 2
    causal_attention_mask = xlnet_base._create_causal_attention_mask(
        seq_length=seq_length,
        memory_length=memory_length)
    expected_output = np.array([[0, 0, 0, 1, 1],
                                [0, 0, 0, 0, 1],
                                [0, 0, 0, 0, 0]])
    self.assertAllClose(causal_attention_mask, expected_output)

  def test_causal_attention_mask_with_same_length(self):
    seq_length, memory_length = 3, 2
    causal_attention_mask = xlnet_base._create_causal_attention_mask(
        seq_length=seq_length,
        memory_length=memory_length,
        same_length=True)
    expected_output = np.array([[0, 0, 0, 1, 1],
                                [1, 0, 0, 0, 1],
                                [1, 1, 0, 0, 0]])
    self.assertAllClose(causal_attention_mask, expected_output)


class MaskComputationTests(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(
      use_input_mask=[False, True],
      use_permutation_mask=[False, True],
      attention_type=["uni", "bi"],
      memory_length=[0, 4],
  ))
  def test_compute_attention_mask_smoke(self,
                                        use_input_mask,
                                        use_permutation_mask,
                                        attention_type,
                                        memory_length):
    """Tests coverage and functionality for different configurations."""
    batch_size = 2
    seq_length = 8
    if use_input_mask:
      input_mask = tf.zeros(shape=(batch_size, seq_length))
    else:
      input_mask = None
    if use_permutation_mask:
      permutation_mask = tf.zeros(shape=(batch_size, seq_length, seq_length))
    else:
      permutation_mask = None
    _, content_mask = xlnet_base._compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type=attention_type,
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)

    expected_mask_shape = (batch_size, 1, seq_length,
                           seq_length + memory_length)
    if use_input_mask or use_permutation_mask:
      self.assertEqual(content_mask.shape, expected_mask_shape)

  def test_no_input_masks(self):
    query_mask, content_mask = xlnet_base._compute_attention_mask(
        input_mask=None,
        permutation_mask=None,
        attention_type="uni",
        seq_length=8,
        memory_length=2,
        batch_size=2,
        dtype=tf.float32)
    self.assertIsNone(query_mask)
    self.assertIsNone(content_mask)

  def test_input_mask_no_permutation(self):
    """Tests if an input mask is provided but not permutation.

    In the case that only one of input mask or permutation mask is provided
    and the attention type is bidirectional, the query mask should be
    a broadcasted version of the provided mask.

    Content mask should be a broadcasted version of the query mask, where the
    diagonal is 0s.
    """
    seq_length = 4
    batch_size = 1
    memory_length = 0

    input_mask = np.array([[0, 0, 1, 1]])
    permutation_mask = None

    expected_query_mask = input_mask[None, None, :, :]
    expected_content_mask = np.array([[[
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 1, 0]]]])

    query_mask, content_mask = xlnet_base._compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type="bi",
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)
    self.assertAllClose(query_mask, expected_query_mask)
    self.assertAllClose(content_mask, expected_content_mask)

  def test_permutation_mask_no_input_mask(self):
    """Tests if a permutation mask is provided but not input."""
    seq_length = 2
    batch_size = 1
    memory_length = 0

    input_mask = None
    permutation_mask = np.array([
        [[0, 1],
         [0, 1]],
    ])
    expected_query_mask = permutation_mask[:, None, :, :]
    expected_content_mask = np.array([[[
        [0, 1],
        [0, 0]]]])

    query_mask, content_mask = xlnet_base._compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type="bi",
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)
    self.assertAllClose(query_mask, expected_query_mask)
    self.assertAllClose(content_mask, expected_content_mask)

  def test_permutation_and_input_mask(self):
    """Tests if both an input and permutation mask are provided."""
    seq_length = 4
    batch_size = 1
    memory_length = 0

    input_mask = np.array([[0, 0, 1, 1]])
    permutation_mask = np.array([[
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]])
    expected_query_mask = np.array([[[
        [1, 0, 1, 1],
        [0, 1, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 1, 1]]]])
    expected_content_mask = np.array([[[
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 1, 0]]]])

    query_mask, content_mask = xlnet_base._compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type="bi",
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)
    self.assertAllClose(query_mask, expected_query_mask)
    self.assertAllClose(content_mask, expected_content_mask)

  def test_permutation_input_uni_mask(self):
    """Tests if an input, permutation and causal mask are provided."""
    seq_length = 4
    batch_size = 1
    memory_length = 0

    input_mask = np.array([[0, 0, 0, 1]])
    permutation_mask = np.array([[
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]])
    expected_query_mask = np.array([[[
        [1, 1, 1, 1],
        [0, 1, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1]]]])
    expected_content_mask = np.array([[[
        [0, 1, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 0]]]])

    query_mask, content_mask = xlnet_base._compute_attention_mask(
        input_mask=input_mask,
        permutation_mask=permutation_mask,
        attention_type="uni",
        seq_length=seq_length,
        memory_length=memory_length,
        batch_size=batch_size,
        dtype=tf.float32)
    self.assertAllClose(query_mask, expected_query_mask)
    self.assertAllClose(content_mask, expected_content_mask)


class SegmentMatrixTests(tf.test.TestCase):

  def test_no_segment_ids(self):
    segment_matrix = xlnet_base._compute_segment_matrix(
        segment_ids=None,
        memory_length=2,
        batch_size=1,
        use_cls_mask=False)
    self.assertIsNone(segment_matrix)

  def test_basic(self):
    batch_size = 1
    memory_length = 0
    segment_ids = np.array([[1, 1, 2, 1]])
    expected_segment_matrix = np.array([[
        [False, False, True, False],
        [False, False, True, False],
        [True, True, False, True],
        [False, False, True, False]
    ]])
    segment_matrix = xlnet_base._compute_segment_matrix(
        segment_ids=segment_ids,
        memory_length=memory_length,
        batch_size=batch_size,
        use_cls_mask=False)
    self.assertAllClose(segment_matrix, expected_segment_matrix)

  def test_basic_with_memory(self):
    batch_size = 1
    memory_length = 1
    segment_ids = np.array([[1, 1, 2, 1]])
    expected_segment_matrix = np.array([[
        [True, False, False, True, False],
        [True, False, False, True, False],
        [True, True, True, False, True],
        [True, False, False, True, False]
    ]]).astype(int)
    segment_matrix = tf.cast(
        xlnet_base._compute_segment_matrix(
            segment_ids=segment_ids,
            memory_length=memory_length,
            batch_size=batch_size,
            use_cls_mask=False),
        dtype=tf.uint8)
    self.assertAllClose(segment_matrix, expected_segment_matrix)

  def dont_test_basic_with_class_mask(self):
    # TODO(allencwang) - this test should pass but illustrates the legacy issue
    # of using class mask. Enable once addressed.
    batch_size = 1
    memory_length = 0
    segment_ids = np.array([[1, 1, 2, 1]])
    expected_segment_matrix = np.array([[
        [False, False, True, False],
        [False, False, True, False],
        [True, True, False, True],
        [False, False, True, False]
    ]]).astype(int)
    segment_matrix = tf.cast(
        xlnet_base._compute_segment_matrix(
            segment_ids=segment_ids,
            memory_length=memory_length,
            batch_size=batch_size,
            use_cls_mask=True),
        dtype=tf.uint8)
    self.assertAllClose(segment_matrix, expected_segment_matrix)


class XLNetModelTests(tf.test.TestCase):

  def _generate_data(self, batch_size, seq_length, num_predictions=None):
    """Generates sample XLNet data for testing."""
    sequence_shape = (batch_size, seq_length)
    if num_predictions is not None:
      target_mapping = tf.random.uniform(
          shape=(batch_size, num_predictions, seq_length))
    return {
        "input_ids": np.random.randint(
            10, size=sequence_shape, dtype="int32"),
        "segment_ids": np.random.randint(
            2, size=sequence_shape, dtype="int32"),
        "input_mask": np.random.randint(
            2, size=sequence_shape).astype("float32"),
        "permutation_mask": np.random.randint(
            2, size=(batch_size, seq_length, seq_length)).astype("float32"),
        "target_mapping": target_mapping,
        "masked_tokens": tf.random.uniform(shape=sequence_shape),
    }

  def test_xlnet_model(self):
    batch_size = 2
    seq_length = 8
    num_predictions = 2
    hidden_size = 4
    xlnet_model = xlnet_base.XLNetBase(
        vocab_size=32000,
        num_layers=2,
        hidden_size=hidden_size,
        num_attention_heads=2,
        head_size=2,
        inner_size=2,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        attention_type="bi",
        bi_data=True,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        reuse_length=0,
        inner_activation="relu")
    input_data = self._generate_data(
        batch_size=batch_size,
        seq_length=seq_length,
        num_predictions=num_predictions)
    model_output = xlnet_model(**input_data)
    self.assertEqual(model_output[0].shape,
                     (batch_size, seq_length, hidden_size))

  def test_get_config(self):
    xlnet_model = xlnet_base.XLNetBase(
        vocab_size=32000,
        num_layers=12,
        hidden_size=36,
        num_attention_heads=12,
        head_size=12,
        inner_size=12,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        attention_type="bi",
        bi_data=True,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        memory_length=0,
        reuse_length=0,
        inner_activation="relu")
    config = xlnet_model.get_config()
    new_xlnet = xlnet_base.XLNetBase.from_config(config)
    self.assertEqual(config, new_xlnet.get_config())


if __name__ == "__main__":
  tf.random.set_seed(0)
  tf.test.main()
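The `test_positional_embedding` case above works its expected values out by hand. The same numbers can be reproduced with a few lines of plain NumPy, which may help when modifying `RelativePositionEncoding`; this is an independent sketch, not part of the committed test file.

# Sketch: reproduces the hand-computed target of test_positional_embedding
# with plain NumPy (hidden_size=4, pos_seq=[1., 0.]).
import numpy as np

hidden_size = 4
pos_seq = np.array([1.0, 0.0])
inv_freq = 1.0 / (10000.0 ** (np.arange(0, hidden_size, 2.0) / hidden_size))
sinusoid = np.einsum("i,d->id", pos_seq, inv_freq)  # [[1., 0.01], [0., 0.]]
pos_emb = np.concatenate([np.sin(sinusoid), np.cos(sinusoid)], axis=-1)
print(pos_emb)
# [[0.84147096 0.00999983 0.54030228 0.99994999]
#  [0.         0.         1.         1.        ]]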
official/nlp/modeling/ops/beam_search_test.py

@@ -14,12 +14,13 @@
 # ==============================================================================
 """Test beam search helper methods."""

+from absl.testing import parameterized
 import tensorflow as tf

 from official.nlp.modeling.ops import beam_search


-class BeamSearchHelperTests(tf.test.TestCase):
+class BeamSearchTests(tf.test.TestCase, parameterized.TestCase):

   def test_expand_to_beam_size(self):
     x = tf.ones([7, 4, 2, 5])
@@ -67,6 +68,41 @@ class BeamSearchHelperTests(tf.test.TestCase):
         [[[4, 5, 6, 7], [8, 9, 10, 11]],
          [[12, 13, 14, 15], [20, 21, 22, 23]]], y)

+  @parameterized.named_parameters([
+      ('padded_decode_true', True),
+      ('padded_decode_false', False),
+  ])
+  def test_sequence_beam_search(self, padded_decode):
+    # batch_size*beam_size, max_decode_length, vocab_size
+    probabilities = tf.constant([[[0.2, 0.7, 0.1],
+                                  [0.5, 0.3, 0.2],
+                                  [0.1, 0.8, 0.1]],
+                                 [[0.1, 0.8, 0.1],
+                                  [0.3, 0.4, 0.3],
+                                  [0.2, 0.1, 0.7]]])
+    # batch_size, max_decode_length, num_heads, embed_size per head
+    x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
+    cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
+
+    def _get_test_symbols_to_logits_fn():
+      """Test function that returns logits for next token."""
+      def symbols_to_logits_fn(_, i, cache):
+        logits = tf.cast(probabilities[:, i, :], tf.float32)
+        return logits, cache
+      return symbols_to_logits_fn
+
+    predictions, _ = beam_search.sequence_beam_search(
+        symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
+        initial_ids=tf.zeros([1], dtype=tf.int32),
+        initial_cache=cache,
+        vocab_size=3,
+        beam_size=2,
+        alpha=0.6,
+        max_decode_length=3,
+        eos_id=9,
+        padded_decode=padded_decode,
+        dtype=tf.float32)
+    self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)

-if __name__ == "__main__":
+
+if __name__ == '__main__':
   tf.test.main()
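The new test drives `beam_search.sequence_beam_search` with a fixed per-step probability table and `alpha=0.6`. The sketch below shows the GNMT-style length normalization that `alpha` conventionally controls in beam search; whether this implementation uses exactly this formula is an assumption here, not something the diff states, so treat it only as motivation for the parameter.

# Sketch: GNMT-style length normalization commonly paired with alpha=0.6 in
# beam search. The exact penalty used by sequence_beam_search is assumed, not
# confirmed by this diff.
import numpy as np

def length_penalty(length, alpha=0.6):
  return ((5.0 + length) / 6.0) ** alpha

def normalized_score(token_probs, alpha=0.6):
  # Sum of token log-probabilities divided by the length penalty, so longer
  # candidates are not rejected purely for accumulating more factors < 1.
  return np.sum(np.log(token_probs)) / length_penalty(len(token_probs), alpha)

print(normalized_score([0.7, 0.8]))       # 2-token candidate
print(normalized_score([0.7, 0.5, 0.8]))  # 3-token candidate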
official/nlp/nhnet/README.md

@@ -104,14 +104,13 @@ Please first install TensorFlow 2 and Tensorflow Model Garden following the
 ```shell
 $ python3 trainer.py \
   --mode=train_and_eval \
   --vocab=/path/to/bert_checkpoint/vocab.txt \
   --init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
   --params_override='init_from_bert2bert=false' \
   --train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
   --model_dir=/path/to/output/model \
   --len_title=15 \
   --len_passage=200 \
-  --max_num_articles=5 \
+  --num_nhnet_articles=5 \
   --model_type=nhnet \
   --train_batch_size=16 \
   --train_steps=10000 \
@@ -123,14 +122,13 @@ $ python3 trainer.py \
 ```shell
 $ python3 trainer.py \
   --mode=train_and_eval \
   --vocab=/path/to/bert_checkpoint/vocab.txt \
   --init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
   --params_override='init_from_bert2bert=false' \
   --train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
   --model_dir=/path/to/output/model \
   --len_title=15 \
   --len_passage=200 \
-  --max_num_articles=5 \
+  --num_nhnet_articles=5 \
   --model_type=nhnet \
   --train_batch_size=1024 \
   --train_steps=10000 \
official/nlp/nhnet/decoder.py  View file @ b0ccdb11
@@ -22,7 +22,6 @@ from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
-from official.nlp.modeling.layers import transformer
from official.nlp.transformer import model_utils as transformer_utils
@@ -59,7 +58,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
    self.layers = []
    for i in range(self.num_hidden_layers):
      self.layers.append(
-          transformer.TransformerDecoderLayer(
+          layers.TransformerDecoderBlock(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self.intermediate_size,
              intermediate_activation=self.intermediate_activation,
official/nlp/nhnet/evaluation.py  View file @ b0ccdb11
@@ -15,11 +15,6 @@
# ==============================================================================
"""Evaluation for Bert2Bert."""

-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
-
import os

# Import libraries
from absl import logging
@@ -114,7 +109,6 @@ def continuous_eval(strategy,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])
-    model.global_step = global_step

    @tf.function
    def test_step(inputs):
@@ -149,7 +143,7 @@ def continuous_eval(strategy,
  eval_results = {}
  for latest_checkpoint in tf.train.checkpoints_iterator(
      model_dir, timeout=timeout):
-    checkpoint = tf.train.Checkpoint(model=model)
+    checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
    checkpoint.restore(latest_checkpoint).expect_partial()
    logging.info("Loaded checkpoint %s", latest_checkpoint)
@@ -162,7 +156,7 @@ def continuous_eval(strategy,
        metric.update_state(func(logits.numpy(), targets.numpy()))

    with eval_summary_writer.as_default():
-      step = model.global_step.numpy()
+      step = global_step.numpy()
      for metric, _ in metrics_and_funcs:
        eval_results[metric.name] = metric.result().numpy().astype(float)
        tf.summary.scalar(
official/nlp/nhnet/trainer.py  View file @ b0ccdb11
@@ -27,13 +27,13 @@ from absl import flags
from absl import logging
from six.moves import zip
import tensorflow as tf
+from official.common import distribute_utils
from official.modeling.hyperparams import params_dict
from official.nlp.nhnet import evaluation
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.nhnet import optimizer
from official.nlp.transformer import metrics as transformer_metrics
-from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

FLAGS = flags.FLAGS
@@ -145,7 +145,6 @@ def train(params, strategy, dataset=None):
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    opt = optimizer.create_optimizer(params)
    trainer = Trainer(model, params)
-    model.global_step = opt.iterations

    trainer.compile(
        optimizer=opt,
@@ -153,12 +152,13 @@ def train(params, strategy, dataset=None):
    summary_dir = os.path.join(FLAGS.model_dir, "summaries")
    summary_callback = tf.keras.callbacks.TensorBoard(
        summary_dir, update_freq=max(100, FLAGS.steps_per_loop))
-    checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+    checkpoint = tf.train.Checkpoint(
+        model=model, optimizer=opt, global_step=opt.iterations)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=FLAGS.model_dir,
        max_to_keep=10,
-        step_counter=model.global_step,
+        step_counter=opt.iterations,
        checkpoint_interval=FLAGS.checkpoint_interval)
    if checkpoint_manager.restore_or_initialize():
      logging.info("Training restored from the checkpoints in: %s",
@@ -185,7 +185,7 @@ def run():
  if FLAGS.enable_mlir_bridge:
    tf.config.experimental.enable_mlir_bridge()

-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
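The trainer change above drops the ad-hoc `model.global_step` attribute and instead registers the optimizer's `iterations` counter in the `tf.train.Checkpoint` under the explicit name `global_step`, which is what the evaluation loop now restores. A minimal sketch of the same pattern with toy stand-ins (the model, directory, and interval below are illustrative, not taken from the commit):

import tensorflow as tf

# Toy stand-ins for the NHNet model/optimizer; only the checkpoint wiring
# mirrors the change above.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.Adam()

# The optimizer's `iterations` variable doubles as the global step, so the
# checkpoint carries it under an explicit `global_step` name.
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt,
                                 global_step=opt.iterations)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/nhnet_ckpt_demo", max_to_keep=10,
    step_counter=opt.iterations, checkpoint_interval=100)

# An evaluation job can then restore the step into its own variable, as the
# evaluation.py change does, without the model exposing a global_step.
eval_step = tf.Variable(0, dtype=tf.int64)
eval_ckpt = tf.train.Checkpoint(model=model, global_step=eval_step)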
official/nlp/projects/bigbird/attention.py  0 → 100644  View file @ b0ccdb11
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based bigbird attention layer."""

import numpy as np
import tensorflow as tf

MAX_SEQ_LEN = 4096


def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_blocked_mask: 2D Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].

  Returns:
    float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
      from_block_size, 3*to_block_size].
  """
  exp_blocked_to_pad = tf.concat([
      to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2],
      to_blocked_mask[:, 3:-1]
  ], 2)
  band_mask = tf.einsum("BLQ,BLK->BLQK", from_blocked_mask[:, 2:-2],
                        exp_blocked_to_pad)
  band_mask = tf.expand_dims(band_mask, 1)
  return band_mask


def bigbird_block_rand_mask(from_seq_length,
                            to_seq_length,
                            from_block_size,
                            to_block_size,
                            num_rand_blocks,
                            last_idx=-1):
  """Create adjacency list of random attention.

  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    num_rand_blocks: int. Number of random chunks per row.
    last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
      if positive then num_rand_blocks blocks chosen only up to last_idx.

  Returns:
    adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
  """
  assert from_seq_length // from_block_size == to_seq_length // to_block_size, \
      "Error the number of blocks needs to be same!"

  rand_attn = np.zeros(
      (from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
  middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
  last = to_seq_length // to_block_size - 1
  if last_idx > (2 * to_block_size):
    last = (last_idx // to_block_size) - 1

  r = num_rand_blocks  # shorthand
  for i in range(1, from_seq_length // from_block_size - 1):
    start = i - 2
    end = i
    if i == 1:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
    elif i == 2:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
    elif i == from_seq_length // from_block_size - 3:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -3: should have been sliced till last-3
    elif i == from_seq_length // from_block_size - 2:
      rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -4: should have been sliced till last-4
    else:
      if start > last:
        start = last
        rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
      elif (end + 1) == last:
        rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
      else:
        rand_attn[i - 1, :] = np.random.permutation(
            np.concatenate((middle_seq[:start], middle_seq[end + 1:last])))[:r]
  return rand_attn


def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn,
                                 num_attention_heads, num_rand_blocks,
                                 batch_size, from_seq_length, from_block_size):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_blocked_mask: 2D Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
    rand_attn: [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks]
    num_attention_heads: int. Number of attention heads.
    num_rand_blocks: int. Number of random chunks per row.
    batch_size: int. Batch size for computation.
    from_seq_length: int. length of from sequence.
    from_block_size: int. size of block in from sequence.

  Returns:
    float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2,
      from_block_size, num_rand_blocks*to_block_size].
  """
  num_windows = from_seq_length // from_block_size - 2
  rand_mask = tf.reshape(
      tf.gather(to_blocked_mask, rand_attn, batch_dims=1),
      [batch_size, num_attention_heads, num_windows,
       num_rand_blocks * from_block_size])
  rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
                        rand_mask)
  return rand_mask


def bigbird_block_sparse_attention(
    query_layer, key_layer, value_layer, band_mask, from_mask, to_mask,
    from_blocked_mask, to_blocked_mask, rand_attn, num_attention_heads,
    num_rand_blocks, size_per_head, batch_size, from_seq_length, to_seq_length,
    from_block_size, to_block_size):
  """BigBird attention sparse calculation using blocks in linear time.

  Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.

  Args:
    query_layer: float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length, size_per_head]
    key_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    value_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    band_mask: (optional) int32 Tensor of shape [batch_size, 1,
      from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
      The values should be 1 or 0. The attention scores will effectively be set
      to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    from_mask: (optional) int32 Tensor of shape [batch_size, 1, from_seq_length,
      1]. The values should be 1 or 0. The attention scores will effectively be
      set to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1, to_seq_length].
      The values should be 1 or 0. The attention scores will effectively be set
      to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size]. Same as from_mask,
      just reshaped.
    to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size]. Same as to_mask, just
      reshaped.
    rand_attn: [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks]
    num_attention_heads: int. Number of attention heads.
    num_rand_blocks: int. Number of random chunks per row.
    size_per_head: int. Size of each attention head.
    batch_size: int. Batch size for computation.
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].
  """
  rand_attn = tf.expand_dims(rand_attn, 0)
  rand_attn = tf.repeat(rand_attn, batch_size, 0)

  rand_mask = create_rand_mask_from_inputs(
      from_blocked_mask,
      to_blocked_mask,
      rand_attn,
      num_attention_heads,
      num_rand_blocks,
      batch_size,
      from_seq_length,
      from_block_size,
  )

  # Define shorthands
  h = num_attention_heads
  r = num_rand_blocks
  d = size_per_head
  b = batch_size
  m = from_seq_length
  n = to_seq_length
  wm = from_block_size
  wn = to_block_size

  query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
  key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
  value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])

  blocked_query_matrix = tf.reshape(query_layer, (b, h, m // wm, wm, -1))
  blocked_key_matrix = tf.reshape(key_layer, (b, h, n // wn, wn, -1))
  blocked_value_matrix = tf.reshape(value_layer, (b, h, n // wn, wn, -1))
  gathered_key = tf.reshape(
      tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"),
      (b, h, m // wm - 2, r * wn, -1))  # [b, h, n//wn-2, r, wn, -1]
  gathered_value = tf.reshape(
      tf.gather(
          blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"),
      (b, h, m // wm - 2, r * wn, -1))  # [b, h, n//wn-2, r, wn, -1]

  first_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
      key_layer)  # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
  first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
  first_product += (1.0 - tf.cast(to_mask, dtype=tf.float32)) * -10000.0
  first_attn_weights = tf.nn.softmax(first_product)  # [b, h, wm, n]
  first_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", first_attn_weights,
      value_layer)  # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
  first_context_layer = tf.expand_dims(first_context_layer, 2)

  second_key_mat = tf.concat([
      blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1],
      blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :, -1],
      gathered_key[:, :, 0]
  ], 2)  # [b, h, (4+r)*wn, -1]
  second_value_mat = tf.concat([
      blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1],
      blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1],
      gathered_value[:, :, 0]
  ], 2)  # [b, h, (4+r)*wn, -1]
  second_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat
  )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
  second_seq_pad = tf.concat([
      to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
      tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
  ], 3)
  second_rand_pad = tf.concat(
      [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, 0]], 3)
  second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
  second_product += (1.0 -
                     tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
  second_attn_weights = tf.nn.softmax(second_product)  # [b , h, wm, (4+r)*wn]
  second_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", second_attn_weights, second_value_mat
  )  # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
  second_context_layer = tf.expand_dims(second_context_layer, 2)

  exp_blocked_key_matrix = tf.concat([
      blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2],
      blocked_key_matrix[:, :, 3:-1]
  ], 3)  # [b, h, m//wm-4, 3*wn, -1]
  exp_blocked_value_matrix = tf.concat([
      blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2],
      blocked_value_matrix[:, :, 3:-1]
  ], 3)  # [b, h, m//wm-4, 3*wn, -1]
  middle_query_matrix = blocked_query_matrix[:, :, 2:-2]
  inner_band_product = tf.einsum(
      "BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix
  )  # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1]
  #   ==> [b, h, m//wm-4, wm, 3*wn]
  inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d))
  rand_band_product = tf.einsum(
      "BHLQD,BHLKD->BHLQK", middle_query_matrix, gathered_key[:, :, 1:-1]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1]
  #   ==> [b, h, m//wm-4, wm, r*wn]
  rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d))
  first_band_product = tf.einsum(
      "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
  first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d))
  last_band_product = tf.einsum(
      "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
  last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d))
  inner_band_product += (1.0 - band_mask) * -10000.0
  first_band_product += (
      1.0 - tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0
  last_band_product += (
      1.0 - tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0
  rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
  band_product = tf.concat([
      first_band_product, inner_band_product, rand_band_product,
      last_band_product
  ], -1)  # [b, h, m//wm-4, wm, (5+r)*wn]
  attn_weights = tf.nn.softmax(band_product)  # [b, h, m//wm-4, wm, (5+r)*wn]
  context_layer = tf.einsum(
      "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, wn:4 * wn],
      exp_blocked_value_matrix
  )  # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1]
  #   ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, 4 * wn:-wn],
      gathered_value[:, :, 1:-1]
  )  # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1]
  #   ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn],
      blocked_value_matrix[:, :, 0]
  )  # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, -wn:],
      blocked_value_matrix[:, :, -1]
  )  # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]

  second_last_key_mat = tf.concat([
      blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3],
      blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1],
      gathered_key[:, :, -1]
  ], 2)  # [b, h, (4+r)*wn, -1]
  second_last_value_mat = tf.concat([
      blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3],
      blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1],
      gathered_value[:, :, -1]
  ], 2)  # [b, h, (4+r)*wn, -1]
  second_last_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat
  )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
  second_last_seq_pad = tf.concat([
      to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
      tf.ones([b, 1, 1, r * wn], dtype=tf.float32)
  ], 3)
  second_last_rand_pad = tf.concat(
      [tf.ones([b, h, wm, 4 * wn], dtype=tf.float32), rand_mask[:, :, -1]], 3)
  second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
  second_last_product += (
      1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
  second_last_attn_weights = tf.nn.softmax(
      second_last_product)  # [b, h, wm, (4+r)*wn]
  second_last_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat
  )  # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
  second_last_context_layer = tf.expand_dims(second_last_context_layer, 2)

  last_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1],
      key_layer)  # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
  last_product = tf.multiply(last_product, 1.0 / np.sqrt(d))
  last_product += (1.0 - to_mask) * -10000.0
  last_attn_weights = tf.nn.softmax(last_product)  # [b, h, wm, n]
  last_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", last_attn_weights,
      value_layer)  # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
  last_context_layer = tf.expand_dims(last_context_layer, 2)

  context_layer = tf.concat([
      first_context_layer, second_context_layer, context_layer,
      second_last_context_layer, last_context_layer
  ], 2)
  context_layer = tf.reshape(context_layer, (b, h, m, -1)) * from_mask
  context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
  return context_layer


class BigBirdMasks(tf.keras.layers.Layer):
  """Creates bigbird attention masks."""

  def __init__(self, block_size, **kwargs):
    super().__init__(**kwargs)
    self._block_size = block_size

  def call(self, inputs):
    encoder_shape = tf.shape(inputs)
    batch_size, seq_length = encoder_shape[0], encoder_shape[1]
    # reshape and cast for blocking
    inputs = tf.cast(inputs, dtype=tf.float32)
    blocked_encoder_mask = tf.reshape(
        inputs, (batch_size, seq_length // self._block_size, self._block_size))
    encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
    encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))

    band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
                                             blocked_encoder_mask)
    return [band_mask, encoder_from_mask, encoder_to_mask, blocked_encoder_mask]


@tf.keras.utils.register_keras_serializable(package="Text")
class BigBirdAttention(tf.keras.layers.MultiHeadAttention):
  """BigBird, a sparse attention mechanism.

  This layer follows the paper "Big Bird: Transformers for Longer Sequences"
  (https://arxiv.org/abs/2007.14062). It reduces the quadratic dependency of
  attention computation to linear.

  Arguments are the same as `MultiHeadAttention` layer.
  """

  def __init__(self,
               num_rand_blocks=3,
               from_block_size=64,
               to_block_size=64,
               max_rand_mask_length=MAX_SEQ_LEN,
               seed=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_rand_blocks = num_rand_blocks
    self._from_block_size = from_block_size
    self._to_block_size = to_block_size
    self._seed = seed

    # Generates random attention.
    np.random.seed(self._seed)
    # pylint: disable=g-complex-comprehension
    rand_attn = [
        bigbird_block_rand_mask(
            max_rand_mask_length,
            max_rand_mask_length,
            from_block_size,
            to_block_size,
            num_rand_blocks,
            last_idx=1024) for _ in range(self._num_heads)
    ]
    # pylint: enable=g-complex-comprehension
    rand_attn = np.stack(rand_attn, axis=0)
    self.rand_attn = tf.constant(rand_attn, dtype=tf.int32)

  def _compute_attention(self, query, key, value, attention_mask=None):
    (band_mask, encoder_from_mask, encoder_to_mask,
     blocked_encoder_mask) = attention_mask
    query_shape = tf.shape(query)
    from_seq_length = query_shape[1]
    to_seq_length = tf.shape(key)[1]
    rand_attn = self.rand_attn[:, :(from_seq_length // self._from_block_size -
                                    2)]
    return bigbird_block_sparse_attention(
        query,
        key,
        value,
        band_mask,
        encoder_from_mask,
        encoder_to_mask,
        blocked_encoder_mask,
        blocked_encoder_mask,
        num_attention_heads=self._num_heads,
        num_rand_blocks=self._num_rand_blocks,
        size_per_head=self._key_dim,
        batch_size=query_shape[0],
        from_seq_length=from_seq_length,
        to_seq_length=to_seq_length,
        from_block_size=self._from_block_size,
        to_block_size=self._to_block_size,
        rand_attn=rand_attn)

  def call(self, query, value, key=None, attention_mask=None, **kwargs):
    if not self._built_from_signature:
      self._build_from_signature(query=query, value=value, key=key)
    if key is None:
      key = value

    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N ,H]
    query = self._query_dense(query)
    # `key` = [B, S, N, H]
    key = self._key_dense(key)
    # `value` = [B, S, N, H]
    value = self._value_dense(value)

    attention_output = self._compute_attention(query, key, value,
                                               attention_mask)
    attention_output.set_shape([None, None, self._num_heads, self._key_dim])
    attention_output = self._output_dense(attention_output)
    return attention_output

  def get_config(self):
    config = {
        "num_rand_blocks": self._num_rand_blocks,
        "from_block_size": self._from_block_size,
        "to_block_size": self._to_block_size,
        "seed": self._seed
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
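To make the linear-time claim in the `BigBirdAttention` docstring concrete, here is a rough count (not part of the commit) of how many key positions each query row scores under the block layout implemented in `bigbird_block_sparse_attention`: the first and last query blocks attend to the full sequence, the two blocks next to them score `(4 + num_rand_blocks)` key blocks, and every middle block scores `(5 + num_rand_blocks)` key blocks (first and last global blocks, a 3-block sliding window, and the random blocks).

def bigbird_scores_per_query_row(seq_length, block_size, num_rand_blocks):
  """Counts key positions scored per query row under the block layout above."""
  num_blocks = seq_length // block_size
  counts = []
  for block in range(num_blocks):
    if block in (0, num_blocks - 1):       # first/last block: full attention
      counts.append(seq_length)
    elif block in (1, num_blocks - 2):     # blocks adjacent to the edges
      counts.append((4 + num_rand_blocks) * block_size)
    else:                                  # band + two global + random blocks
      counts.append((5 + num_rand_blocks) * block_size)
  return counts

# Example: 4096 tokens, 64-token blocks, 3 random blocks per row.
counts = bigbird_scores_per_query_row(4096, 64, 3)
sparse_total = sum(c * 64 for c in counts)   # scores per head, all query rows
dense_total = 4096 * 4096                    # full attention for comparison
print(sparse_total / dense_total)            # roughly 0.15 for these settings

For a fixed block size the sparse total grows linearly with sequence length (the two full-attention edge blocks contribute only 2 * block_size rows), which is the trade-off this layer is built around.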
official/vision/beta/tasks/image_classification_test.py → official/nlp/projects/bigbird/attention_test.py  View file @ b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,46 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for image classification task."""
+"""Tests for official.nlp.projects.bigbird.attention."""

-# pylint: disable=unused-import
-from absl.testing import parameterized
-import orbit
import tensorflow as tf

-from official.core import exp_factory
-from official.modeling import optimization
-from official.vision import beta
-from official.vision.beta.tasks import image_classification as img_cls_task
-
-
-class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.parameters(('resnet_imagenet'), ('revnet_imagenet'))
-  def test_task(self, config_name):
-    config = exp_factory.get_exp_config(config_name)
-    config.task.train_data.global_batch_size = 2
-
-    task = img_cls_task.ImageClassificationTask(config.task)
-    model = task.build_model()
-    metrics = task.build_metrics()
-    strategy = tf.distribute.get_strategy()
-
-    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
-                                                   config.task.train_data)
-
-    iterator = iter(dataset)
-    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
-    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
-
-    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
-    self.assertIn('loss', logs)
-    self.assertIn('accuracy', logs)
-    self.assertIn('top_5_accuracy', logs)
-    logs = task.validation_step(next(iterator), model, metrics=metrics)
-    self.assertIn('loss', logs)
-    self.assertIn('accuracy', logs)
-    self.assertIn('top_5_accuracy', logs)
+from official.nlp.projects.bigbird import attention
+
+
+class BigbirdAttentionTest(tf.test.TestCase):
+
+  def test_attention(self):
+    num_heads = 12
+    key_dim = 64
+    seq_length = 1024
+    batch_size = 2
+    block_size = 64
+    mask_layer = attention.BigBirdMasks(block_size=block_size)
+    encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
+    masks = mask_layer(encoder_inputs_mask)
+    test_layer = attention.BigBirdAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        from_block_size=block_size,
+        to_block_size=block_size,
+        seed=0)
+    query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    value = query
+    output = test_layer(query=query, value=value, attention_mask=masks)
+    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])
+
+  def test_config(self):
+    num_heads = 12
+    key_dim = 64
+    block_size = 64
+    test_layer = attention.BigBirdAttention(
+        num_heads=num_heads,
+        key_dim=key_dim,
+        from_block_size=block_size,
+        to_block_size=block_size,
+        seed=0)
+    print(test_layer.get_config())
+    new_layer = attention.BigBirdAttention.from_config(test_layer.get_config())
+    # If the serialization was successful, the new config should match the old.
+    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())


if __name__ == '__main__':
official/nlp/projects/bigbird/encoder.py  0 → 100644  View file @ b0ccdb11
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes

import tensorflow as tf

from official.modeling import activations
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.projects.bigbird import attention


@tf.keras.utils.register_keras_serializable(package='Text')
class BigBirdEncoder(tf.keras.Model):
  """Transformer-based encoder network with BigBird attentions.

  *Note* that the network is constructed by
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    vocab_size: The size of the token vocabulary.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer.
      The hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, max_sequence_length uses the value from sequence
      length. This determines the variable shape for positional embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    intermediate_size: The intermediate size for the transformer layers.
    activation: The activation to use for the transformer layers.
    dropout_rate: The dropout rate to use for the transformer layers.
    attention_dropout_rate: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
    embedding_width: The width of the word embeddings. If the embedding width
      is not equal to hidden size, embedding parameters will be factorized into
      two matrices in the shape of ['vocab_size', 'embedding_width'] and
      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
      smaller than 'hidden_size').
  """

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_sequence_length=attention.MAX_SEQ_LEN,
               type_vocab_size=16,
               intermediate_size=3072,
               block_size=64,
               num_rand_blocks=3,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               embedding_width=None,
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'block_size': block_size,
        'num_rand_blocks': num_rand_blocks,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        'embedding_width': embedding_width,
    }

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    if embedding_width is None:
      embedding_width = hidden_size
    self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = self._position_embedding_layer(word_embeddings)
    self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = self._embedding_projection(embeddings)

    self._transformer_layers = []
    data = embeddings
    masks = attention.BigBirdMasks(block_size=block_size)(mask)
    encoder_outputs = []
    attn_head_dim = hidden_size // num_attention_heads
    for i in range(num_layers):
      layer = layers.TransformerScaffold(
          num_attention_heads,
          intermediate_size,
          activation,
          attention_cls=attention.BigBirdAttention,
          attention_cfg=dict(
              num_heads=num_attention_heads,
              key_dim=attn_head_dim,
              kernel_initializer=initializer,
              from_block_size=block_size,
              to_block_size=block_size,
              num_rand_blocks=num_rand_blocks,
              max_rand_mask_length=max_sequence_length,
              seed=i),
          dropout_rate=dropout_rate,
          attention_dropout_rate=dropout_rate,
          kernel_initializer=initializer)
      self._transformer_layers.append(layer)
      data = layer([data, masks])
      encoder_outputs.append(data)

    outputs = dict(
        sequence_output=encoder_outputs[-1], encoder_outputs=encoder_outputs)

    super().__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return self._config_dict

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
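A minimal usage sketch for the encoder added above (assuming the package layout in this commit); the point is that the functional model returns a dict with `sequence_output` plus a per-layer `encoder_outputs` list:

import numpy as np
from official.nlp.projects.bigbird import encoder

network = encoder.BigBirdEncoder(
    vocab_size=1024, num_layers=1, block_size=64, max_sequence_length=4096)
batch_size, seq_length = 2, 1024
word_ids = np.random.randint(1024, size=(batch_size, seq_length))
mask = np.ones((batch_size, seq_length), dtype=np.int32)
type_ids = np.zeros((batch_size, seq_length), dtype=np.int32)

outputs = network([word_ids, mask, type_ids])
print(outputs["sequence_output"].shape)   # (2, 1024, 768) with default hidden_size
print(len(outputs["encoder_outputs"]))    # 1: one entry per transformer layer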
official/vision/beta/modeling/layers/roi_sampler_test.py → official/nlp/projects/bigbird/encoder_test.py  View file @ b0ccdb11
@@ -12,65 +12,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for roi_sampler.py."""
+"""Tests for official.nlp.projects.bigbird.encoder."""

-# Import libraries
import numpy as np
import tensorflow as tf

-from official.vision.beta.modeling.layers import roi_sampler
-
-
-class ROISamplerTest(tf.test.TestCase):
-
-  def test_roi_sampler(self):
-    boxes_np = np.array(
-        [[[0, 0, 5, 5], [2.5, 2.5, 7.5, 7.5],
-          [5, 5, 10, 10], [7.5, 7.5, 12.5, 12.5]]])
-    boxes = tf.constant(boxes_np, dtype=tf.float32)
-    gt_boxes_np = np.array(
-        [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]]])
-    gt_boxes = tf.constant(gt_boxes_np, dtype=tf.float32)
-    gt_classes_np = np.array([[2, 10, -1]])
-    gt_classes = tf.constant(gt_classes_np, dtype=tf.int32)
-    generator = roi_sampler.ROISampler(
-        mix_gt_boxes=True,
-        num_sampled_rois=2,
-        foreground_fraction=0.5,
-        foreground_iou_threshold=0.5,
-        background_iou_high_threshold=0.5,
-        background_iou_low_threshold=0.0)
-    # Runs on TPU.
-    strategy = tf.distribute.experimental.TPUStrategy()
-    with strategy.scope():
-      _ = generator(boxes, gt_boxes, gt_classes)
-    # Runs on CPU.
-    _ = generator(boxes, gt_boxes, gt_classes)
-
-  def test_serialize_deserialize(self):
-    kwargs = dict(
-        mix_gt_boxes=True,
-        num_sampled_rois=512,
-        foreground_fraction=0.25,
-        foreground_iou_threshold=0.5,
-        background_iou_high_threshold=0.5,
-        background_iou_low_threshold=0.5,
-    )
-    generator = roi_sampler.ROISampler(**kwargs)
-    expected_config = dict(kwargs)
-    self.assertEqual(generator.get_config(), expected_config)
-    new_generator = roi_sampler.ROISampler.from_config(generator.get_config())
-    self.assertAllEqual(generator.get_config(), new_generator.get_config())
-
-
-if __name__ == '__main__':
+from official.nlp.projects.bigbird import encoder
+
+
+class BigBirdEncoderTest(tf.test.TestCase):
+
+  def test_encoder(self):
+    sequence_length = 1024
+    batch_size = 2
+    vocab_size = 1024
+    network = encoder.BigBirdEncoder(
+        num_layers=1, vocab_size=1024, max_sequence_length=4096)
+    word_id_data = np.random.randint(
+        vocab_size, size=(batch_size, sequence_length))
+    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+    type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
+    outputs = network([word_id_data, mask_data, type_id_data])
+    self.assertEqual(outputs["sequence_output"].shape,
+                     (batch_size, sequence_length, 768))
+
+  def test_save_restore(self):
+    sequence_length = 1024
+    batch_size = 2
+    vocab_size = 1024
+    network = encoder.BigBirdEncoder(
+        num_layers=1, vocab_size=1024, max_sequence_length=4096)
+    word_id_data = np.random.randint(
+        vocab_size, size=(batch_size, sequence_length))
+    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+    type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
+    inputs = dict(
+        input_word_ids=word_id_data,
+        input_mask=mask_data,
+        input_type_ids=type_id_data)
+    ref_outputs = network(inputs)
+    model_path = self.get_temp_dir() + "/model"
+    network.save(model_path)
+    loaded = tf.keras.models.load_model(model_path)
+    outputs = loaded(inputs)
+    self.assertAllClose(outputs["sequence_output"],
+                        ref_outputs["sequence_output"])
+
+
+if __name__ == "__main__":
  tf.test.main()
official/nlp/tasks/masked_lm.py  View file @ b0ccdb11
@@ -63,31 +63,33 @@ class MaskedLMTask(base_task.Task):
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
-    metrics = dict([(metric.name, metric) for metric in metrics])
-    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
-        labels['masked_lm_ids'],
-        tf.cast(model_outputs['lm_output'], tf.float32),
-        from_logits=True)
-    lm_label_weights = labels['masked_lm_weights']
-    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
-    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
-    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
-    metrics['lm_example_loss'].update_state(mlm_loss)
-    if 'next_sentence_labels' in labels:
-      sentence_labels = labels['next_sentence_labels']
-      sentence_outputs = tf.cast(
-          model_outputs['next_sentence'], dtype=tf.float32)
-      sentence_loss = tf.reduce_mean(
-          tf.keras.losses.sparse_categorical_crossentropy(
-              sentence_labels, sentence_outputs, from_logits=True))
-      metrics['next_sentence_loss'].update_state(sentence_loss)
-      total_loss = mlm_loss + sentence_loss
-    else:
-      total_loss = mlm_loss
-
-    if aux_losses:
-      total_loss += tf.add_n(aux_losses)
-    return total_loss
+    with tf.name_scope('MaskedLMTask/losses'):
+      metrics = dict([(metric.name, metric) for metric in metrics])
+      lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
+          labels['masked_lm_ids'],
+          tf.cast(model_outputs['mlm_logits'], tf.float32),
+          from_logits=True)
+      lm_label_weights = labels['masked_lm_weights']
+      lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
+      lm_denominator_loss = tf.reduce_sum(lm_label_weights)
+      mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
+      metrics['lm_example_loss'].update_state(mlm_loss)
+      if 'next_sentence_labels' in labels:
+        sentence_labels = labels['next_sentence_labels']
+        sentence_outputs = tf.cast(
+            model_outputs['next_sentence'], dtype=tf.float32)
+        sentence_loss = tf.reduce_mean(
+            tf.keras.losses.sparse_categorical_crossentropy(
+                sentence_labels, sentence_outputs, from_logits=True))
+        metrics['next_sentence_loss'].update_state(sentence_loss)
+        total_loss = mlm_loss + sentence_loss
+      else:
+        total_loss = mlm_loss
+
+      if aux_losses:
+        total_loss += tf.add_n(aux_losses)
+      return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
@@ -128,14 +130,15 @@ class MaskedLMTask(base_task.Task):
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
-    metrics = dict([(metric.name, metric) for metric in metrics])
-    if 'masked_lm_accuracy' in metrics:
-      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
-                                                 model_outputs['lm_output'],
-                                                 labels['masked_lm_weights'])
-    if 'next_sentence_accuracy' in metrics:
-      metrics['next_sentence_accuracy'].update_state(
-          labels['next_sentence_labels'], model_outputs['next_sentence'])
+    with tf.name_scope('MaskedLMTask/process_metrics'):
+      metrics = dict([(metric.name, metric) for metric in metrics])
+      if 'masked_lm_accuracy' in metrics:
+        metrics['masked_lm_accuracy'].update_state(
+            labels['masked_lm_ids'], model_outputs['mlm_logits'],
+            labels['masked_lm_weights'])
+      if 'next_sentence_accuracy' in metrics:
+        metrics['next_sentence_accuracy'].update_state(
+            labels['next_sentence_labels'], model_outputs['next_sentence'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
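The loss computed above is a weighted cross-entropy: per-position losses are multiplied by `masked_lm_weights` and normalized by the total weight, with `divide_no_nan` guarding against an all-zero weight sum. A tiny numeric sketch of that reduction (the values are made up purely for illustration):

import tensorflow as tf

# Per-position MLM losses for 2 examples x 3 prediction slots (illustrative).
per_position_loss = tf.constant([[2.0, 1.0, 4.0],
                                 [3.0, 0.5, 6.0]])
# Weight 0 marks padded prediction slots that should not contribute.
weights = tf.constant([[1.0, 1.0, 0.0],
                       [1.0, 1.0, 0.0]])

numerator = tf.reduce_sum(per_position_loss * weights)   # 2 + 1 + 3 + 0.5 = 6.5
denominator = tf.reduce_sum(weights)                      # 4.0
mlm_loss = tf.math.divide_no_nan(numerator, denominator)  # 6.5 / 4.0 = 1.625
print(float(mlm_loss))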
official/nlp/tasks/utils.py  View file @ b0ccdb11
@@ -16,25 +16,51 @@
"""Common utils for tasks."""
from typing import Any, Callable

+from absl import logging
import orbit
import tensorflow as tf
import tensorflow_hub as hub


-def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
-  """Gets an encoder from hub."""
+def get_encoder_from_hub(hub_model) -> tf.keras.Model:
+  """Gets an encoder from hub.
+
+  Args:
+    hub_model: A tfhub model loaded by `hub.load(...)`.
+
+  Returns:
+    A tf.keras.Model.
+  """
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_type_ids')
-  hub_layer = hub.KerasLayer(hub_module, trainable=True)
-  pooled_output, sequence_output = hub_layer(
-      [input_word_ids, input_mask, input_type_ids])
-  return tf.keras.Model(
-      inputs=[input_word_ids, input_mask, input_type_ids],
-      outputs=[sequence_output, pooled_output])
+  hub_layer = hub.KerasLayer(hub_model, trainable=True)
+  output_dict = {}
+  dict_input = dict(
+      input_word_ids=input_word_ids,
+      input_mask=input_mask,
+      input_type_ids=input_type_ids)
+  # The legacy hub model takes a list as input and returns a Tuple of
+  # `pooled_output` and `sequence_output`, while the new hub model takes dict
+  # as input and returns a dict.
+  # TODO(chendouble): Remove the support of legacy hub model when the new ones
+  # are released.
+  hub_output_signature = hub_model.signatures['serving_default'].outputs
+  if len(hub_output_signature) == 2:
+    logging.info('Use the legacy hub module with list as input/output.')
+    pooled_output, sequence_output = hub_layer(
+        [input_word_ids, input_mask, input_type_ids])
+    output_dict['pooled_output'] = pooled_output
+    output_dict['sequence_output'] = sequence_output
+  else:
+    logging.info('Use the new hub module with dict as input/output.')
+    output_dict = hub_layer(dict_input)
+  return tf.keras.Model(inputs=dict_input, outputs=output_dict)


def predict(predict_step_fn: Callable[[Any], Any],
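With this change `get_encoder_from_hub` takes an already-loaded hub object instead of a handle string and branches on the number of outputs in its `serving_default` signature. A hedged usage sketch; the TF-Hub handle below is a placeholder, not something named by the commit:

import tensorflow_hub as hub
from official.nlp.tasks import utils

# Hypothetical handle; any BERT-style encoder SavedModel on TF-Hub would do.
hub_handle = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"
hub_model = hub.load(hub_handle)          # load first, then wrap
bert_encoder = utils.get_encoder_from_hub(hub_model)

# The wrapper is a dict-in/dict-out Keras model regardless of whether the
# underlying SavedModel uses the legacy list interface or the new dict one;
# in both branches the outputs include 'sequence_output' and 'pooled_output'.
print([t.name for t in bert_encoder.inputs])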
official/nlp/train.py  View file @ b0ccdb11
@@ -20,6 +20,7 @@ from absl import flags
import gin

from official.core import train_utils
+from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
@@ -27,7 +28,6 @@ from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.modeling import performance
-from official.utils.misc import distribution_utils

FLAGS = flags.FLAGS
@@ -48,11 +48,12 @@ def main(_):
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale)
-  distribution_strategy = distribution_utils.get_distribution_strategy(
+  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
+      tpu_address=params.runtime.tpu,
+      **params.runtime.model_parallelism())
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)
official/nlp/train_ctl_continuous_finetune.py  View file @ b0ccdb11
@@ -15,9 +15,10 @@
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
+import gc
import os
import time
-from typing import Mapping, Any
+from typing import Any, Mapping, Optional

from absl import app
from absl import flags
@@ -28,38 +29,44 @@ import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
+from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.hyperparams import config_definitions
-from official.utils.misc import distribution_utils

FLAGS = flags.FLAGS

+flags.DEFINE_integer(
+    'pretrain_steps',
+    default=None,
+    help='The number of total training steps for the pretraining job.')
+

def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
  """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
-    mode: A 'str', specifying the mode.
-      continuous_train_and_eval - monitors a checkpoint directory. Once a new
-      checkpoint is discovered, loads the checkpoint, finetune the model by
-      training it (probably on another dataset or with another task), then
-      evaluate the finetuned model.
+    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
+      checkpoint directory. Once a new checkpoint is discovered, loads the
+      checkpoint, finetune the model by training it (probably on another
+      dataset or with another task), then evaluate the finetuned model.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
+    pretrain_steps: Optional, the number of total training steps for the
+      pretraining job.

  Returns:
    eval logs: returns eval metrics logs when run_post_eval is set to True,
@@ -77,7 +84,7 @@ def run_continuous_finetune(
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale)
-  distribution_strategy = distribution_utils.get_distribution_strategy(
+  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
@@ -95,10 +102,24 @@ def run_continuous_finetune(
  summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, 'eval'))

+  global_step = 0
+
+  def timeout_fn():
+    if pretrain_steps and global_step < pretrain_steps:
+      # Keeps waiting for another timeout period.
+      logging.info(
+          'Continue waiting for new checkpoint as current pretrain '
+          'global_step=%d and target is %d.', global_step, pretrain_steps)
+      return False
+    # Quits the loop.
+    return True
+
  for pretrain_ckpt in tf.train.checkpoints_iterator(
      checkpoint_dir=params.task.init_checkpoint,
      min_interval_secs=10,
-      timeout=params.trainer.continuous_eval_timeout):
+      timeout=params.trainer.continuous_eval_timeout,
+      timeout_fn=timeout_fn):
    with distribution_strategy.scope():
      global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
@@ -139,6 +160,13 @@ def run_continuous_finetune(
      train_utils.write_summary(summary_writer, global_step, summaries)
      train_utils.remove_ckpts(model_dir)

+    # In TF2, the resource life cycle is bound with the python object life
+    # cycle. Force trigger python garbage collection here so those resources
+    # can be deallocated in time, so it doesn't cause OOM when allocating new
+    # objects.
+    # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
+    # if we need gc here.
+    gc.collect()

  if run_post_eval:
    return eval_metrics
@@ -150,7 +178,7 @@ def main(_):
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  train_utils.serialize_config(params, model_dir)
-  run_continuous_finetune(FLAGS.mode, params, model_dir)
+  run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)

if __name__ == '__main__':
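The `timeout_fn` added above controls what `tf.train.checkpoints_iterator` does when its timeout expires: returning `False` keeps the iterator waiting for the pretraining job to write more checkpoints, returning `True` ends the loop. A minimal standalone sketch of that contract, independent of the TFM config objects (the directory and step counts are illustrative):

import tensorflow as tf

checkpoint_dir = "/tmp/pretrain_ckpts"   # hypothetical pretraining output dir
pretrain_steps = 10000
global_step = 0                          # updated after each restored checkpoint

def timeout_fn():
  if pretrain_steps and global_step < pretrain_steps:
    # Keep polling: the pretraining job has not reached its final step yet.
    return False
  # Quit the loop once the target step has been seen (or no target was set).
  return True

for ckpt_path in tf.train.checkpoints_iterator(
    checkpoint_dir, min_interval_secs=10, timeout=60, timeout_fn=timeout_fn):
  # ... restore ckpt_path, read global_step from it, finetune and evaluate ...
  print("new pretraining checkpoint:", ckpt_path)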
official/nlp/train_ctl_continuous_finetune_test.py  View file @ b0ccdb11
@@ -15,10 +15,9 @@
# ==============================================================================
import os

# Import libraries
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf

from official.common import flags as tfm_flags
from official.core import task_factory
@@ -31,14 +30,14 @@ FLAGS = flags.FLAGS
tfm_flags.define_flags()


-class MainContinuousFinetuneTest(tf.test.TestCase):
+class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
-    super(MainContinuousFinetuneTest, self).setUp()
+    super().setUp()
    self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')

  @flagsaver.flagsaver
-  def testTrainCtl(self):
+  @parameterized.parameters(None, 1)
+  def testTrainCtl(self, pretrain_steps):
    src_model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
@@ -81,7 +80,11 @@ class MainContinuousFinetuneTest(tf.test.TestCase):
      params = train_utils.parse_configuration(FLAGS)
      eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
-          FLAGS.mode, params, FLAGS.model_dir, run_post_eval=True)
+          FLAGS.mode,
+          params,
+          FLAGS.model_dir,
+          run_post_eval=True,
+          pretrain_steps=pretrain_steps)
      self.assertIn('best_acc', eval_metrics)