Commit 91ccdb45
authored Sep 22, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Sep 22, 2020

Internal change

PiperOrigin-RevId: 333019996
parent 8d11ee24
Showing 5 changed files with 858 additions and 7 deletions (+858 / -7):

  official/nlp/configs/encoders.py                 +47   -7
  official/nlp/projects/bigbird/attention.py       +485  -0
  official/nlp/projects/bigbird/attention_test.py  +67   -0
  official/nlp/projects/bigbird/encoder.py         +196  -0
  official/nlp/projects/bigbird/encoder_test.py    +63   -0

official/nlp/configs/encoders.py
@@ -28,6 +28,7 @@ from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder


@dataclasses.dataclass

@@ -60,18 +61,18 @@ class MobileBertEncoderConfig(hyperparams.Config):
    num_blocks: number of transformer block in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
      layer.
    intermediate_act_fn: the non-linear activation function to apply to the
      output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of bottleneck.
    initializer_range: The stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    key_query_shared_bottleneck: whether to share linear transformation for
      keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization_type, only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear

@@ -116,12 +117,32 @@ class AlbertEncoderConfig(hyperparams.Config):
  initializer_range: float = 0.02


@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
  """BigBird encoder configuration."""
  vocab_size: int = 50358
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 4096
  num_rand_blocks: int = 3
  block_size: int = 64
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None


@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = AlbertEncoderConfig()
  bert: BertEncoderConfig = BertEncoderConfig()
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()

@@ -129,6 +150,7 @@ ENCODER_CLS = {
    "bert": networks.BertEncoder,
    "mobilebert": networks.MobileBERTEncoder,
    "albert": networks.AlbertTransformerEncoder,
    "bigbird": bigbird_encoder.BigBirdEncoder,
}

@@ -226,6 +248,24 @@ def build_encoder(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)
  if encoder_type == "bigbird":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        block_size=encoder_cfg.block_size,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        embedding_width=encoder_cfg.embedding_size)

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
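
Note: the new "bigbird" branch above is selected through the EncoderConfig OneOfConfig and the ENCODER_CLS registry. A minimal sketch of how that wiring could be exercised, assuming build_encoder accepts an EncoderConfig instance directly (its full signature is truncated in this hunk), follows:

    from official.nlp.configs import encoders

    # Pick the BigBird encoder and override a few of the defaults shown above.
    config = encoders.EncoderConfig(
        type="bigbird",
        bigbird=encoders.BigBirdEncoderConfig(
            block_size=64,
            num_rand_blocks=3,
            max_position_embeddings=4096))

    # Assumed call pattern: build_encoder dispatches on config.type and returns
    # a bigbird_encoder.BigBirdEncoder built from BigBirdEncoderConfig.
    network = encoders.build_encoder(config)
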
official/nlp/projects/bigbird/attention.py (new file, mode 100644)

This diff is collapsed (485 added lines not shown).
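
Because attention.py is collapsed here, only the knobs it exposes are visible from the other files (BigBirdMasks, BigBirdAttention, block_size, from_block_size/to_block_size, num_rand_blocks, seed, MAX_SEQ_LEN). As a rough illustration of why those knobs matter, and emphatically not the collapsed file's implementation: BigBird-style attention works at block granularity, letting each query block attend to its own block, a small window of neighbors, and a few random blocks, so the score matrix grows roughly linearly with sequence length rather than quadratically.

    import tensorflow as tf

    # Illustration only: reshape a [batch, seq_len] padding mask into blocks,
    # the granularity at which block-sparse attention decides what attends to what.
    def blocked_padding_mask(mask, block_size):
      # mask: eager int tensor of shape [batch, seq_len];
      # seq_len must divide evenly by block_size.
      batch_size, seq_len = mask.shape
      return tf.reshape(tf.cast(mask, tf.float32),
                        [batch_size, seq_len // block_size, block_size])

    # Back-of-the-envelope score-matrix sizes for the defaults in this commit.
    # The split into 3 window + 3 random + 2 global blocks is an assumption.
    seq_len, block_size, num_rand_blocks = 4096, 64, 3
    num_blocks = seq_len // block_size                       # 64 blocks
    attended = 3 + num_rand_blocks + 2                       # 8 blocks per query block
    print(num_blocks * block_size * attended * block_size)   # ~2.1M scores
    print(seq_len * seq_len)                                 # ~16.8M scores for dense attention
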
official/nlp/projects/bigbird/attention_test.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.projects.bigbird.attention."""
import
tensorflow
as
tf
from
official.nlp.projects.bigbird
import
attention
class
BigbirdAttentionTest
(
tf
.
test
.
TestCase
):
def
test_attention
(
self
):
num_heads
=
12
key_dim
=
64
seq_length
=
1024
batch_size
=
2
block_size
=
64
mask_layer
=
attention
.
BigBirdMasks
(
block_size
=
block_size
)
encoder_inputs_mask
=
tf
.
zeros
((
batch_size
,
seq_length
),
dtype
=
tf
.
int32
)
masks
=
mask_layer
(
encoder_inputs_mask
)
test_layer
=
attention
.
BigBirdAttention
(
num_heads
=
num_heads
,
key_dim
=
key_dim
,
from_block_size
=
block_size
,
to_block_size
=
block_size
,
seed
=
0
)
query
=
tf
.
random
.
normal
(
shape
=
(
batch_size
,
seq_length
,
key_dim
))
value
=
query
output
=
test_layer
(
query
=
query
,
value
=
value
,
attention_mask
=
masks
)
self
.
assertEqual
(
output
.
shape
,
[
batch_size
,
seq_length
,
key_dim
])
def
test_config
(
self
):
num_heads
=
12
key_dim
=
64
block_size
=
64
test_layer
=
attention
.
BigBirdAttention
(
num_heads
=
num_heads
,
key_dim
=
key_dim
,
from_block_size
=
block_size
,
to_block_size
=
block_size
,
seed
=
0
)
print
(
test_layer
.
get_config
())
new_layer
=
attention
.
BigBirdAttention
.
from_config
(
test_layer
.
get_config
())
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
test_layer
.
get_config
(),
new_layer
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/nlp/projects/bigbird/encoder.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
import
tensorflow
as
tf
from
official.modeling
import
activations
from
official.nlp
import
keras_nlp
from
official.nlp.modeling
import
layers
from
official.nlp.projects.bigbird
import
attention
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
BigBirdEncoder
(
tf
.
keras
.
Model
):
"""Transformer-based encoder network with BigBird attentions.
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers.
activation: The activation to use for the transformer layers.
dropout_rate: The dropout rate to use for the transformer layers.
attention_dropout_rate: The dropout rate to use for the attention layers
within the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
"""
  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               max_sequence_length=attention.MAX_SEQ_LEN,
               type_vocab_size=16,
               intermediate_size=3072,
               block_size=64,
               num_rand_blocks=3,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               embedding_width=None,
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'block_size': block_size,
        'num_rand_blocks': num_rand_blocks,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        'embedding_width': embedding_width,
    }

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')

    if embedding_width is None:
      embedding_width = hidden_size
    self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = self._position_embedding_layer(word_embeddings)

    self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = self._embedding_projection(embeddings)

    self._transformer_layers = []
    data = embeddings
    masks = attention.BigBirdMasks(block_size=block_size)(mask)
    encoder_outputs = []
    attn_head_dim = hidden_size // num_attention_heads
    for i in range(num_layers):
      layer = layers.TransformerScaffold(
          num_attention_heads,
          intermediate_size,
          activation,
          attention_cls=attention.BigBirdAttention,
          attention_cfg=dict(
              num_heads=num_attention_heads,
              key_dim=attn_head_dim,
              kernel_initializer=initializer,
              from_block_size=block_size,
              to_block_size=block_size,
              num_rand_blocks=num_rand_blocks,
              max_rand_mask_length=max_sequence_length,
              seed=i),
          dropout_rate=dropout_rate,
          attention_dropout_rate=dropout_rate,
          kernel_initializer=initializer)
      self._transformer_layers.append(layer)
      data = layer([data, masks])
      encoder_outputs.append(data)

    outputs = dict(
        sequence_output=encoder_outputs[-1], encoder_outputs=encoder_outputs)

    super().__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return self._config_dict

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
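
One thing the class docstring above calls out is the ALBERT-style factorization switched on by embedding_width. A quick back-of-the-envelope sketch of the parameter savings, using a hypothetical embedding_width of 128 (only vocab_size=50358 and hidden_size=768 come from the defaults added in this commit):

    # Hypothetical sizes for illustration; embedding_width=128 is an assumption.
    vocab_size, hidden_size, embedding_width = 50358, 768, 128

    dense = vocab_size * hidden_size                       # ~38.7M embedding parameters
    factorized = (vocab_size * embedding_width
                  + embedding_width * hidden_size)         # ~6.5M parameters
    print(dense, factorized)
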
official/nlp/projects/bigbird/encoder_test.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.projects.bigbird.encoder."""
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.projects.bigbird
import
encoder
class
BigBirdEncoderTest
(
tf
.
test
.
TestCase
):
def
test_encoder
(
self
):
sequence_length
=
1024
batch_size
=
2
vocab_size
=
1024
network
=
encoder
.
BigBirdEncoder
(
num_layers
=
1
,
vocab_size
=
1024
,
max_sequence_length
=
4096
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
))
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
))
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
))
outputs
=
network
([
word_id_data
,
mask_data
,
type_id_data
])
self
.
assertEqual
(
outputs
[
"sequence_output"
].
shape
,
(
batch_size
,
sequence_length
,
768
))
def
test_save_restore
(
self
):
sequence_length
=
1024
batch_size
=
2
vocab_size
=
1024
network
=
encoder
.
BigBirdEncoder
(
num_layers
=
1
,
vocab_size
=
1024
,
max_sequence_length
=
4096
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
))
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
))
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
))
inputs
=
dict
(
input_word_ids
=
word_id_data
,
input_mask
=
mask_data
,
input_type_ids
=
type_id_data
)
ref_outputs
=
network
(
inputs
)
model_path
=
self
.
get_temp_dir
()
+
"/model"
network
.
save
(
model_path
)
loaded
=
tf
.
keras
.
models
.
load_model
(
model_path
)
outputs
=
loaded
(
inputs
)
self
.
assertAllClose
(
outputs
[
"sequence_output"
],
ref_outputs
[
"sequence_output"
])
if
__name__
==
"__main__"
:
tf
.
test
.
main
()