ModelZoo / ResNet50_tensorflow · Commits

Commit 102f267e, authored May 17, 2021 by Frederick Liu, committed by A. Unique TensorFlower on May 17, 2021
Parent: 34731381

Internal change

PiperOrigin-RevId: 374236491
Showing 6 changed files with 165 additions and 62 deletions:

- official/nlp/configs/encoders.py (+48, -30)
- official/nlp/modeling/networks/encoder_scaffold.py (+62, -22)
- official/nlp/modeling/networks/encoder_scaffold_test.py (+47, -2)
- official/nlp/projects/bigbird/attention.py (+6, -5)
- official/nlp/projects/bigbird/attention_test.py (+1, -1)
- official/nlp/projects/bigbird/encoder.py (+1, -2)
official/nlp/configs/encoders.py (view file @ 102f267e)

@@ -18,15 +18,15 @@ Includes configurations and factory methods.
 """
 from typing import Optional
-from absl import logging
 import dataclasses
 import gin
 import tensorflow as tf

 from official.modeling import hyperparams
 from official.modeling import tf_utils
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
-from official.nlp.projects.bigbird import encoder as bigbird_encoder
+from official.nlp.projects.bigbird import attention as bigbird_attention


 @dataclasses.dataclass
@@ -177,15 +177,6 @@ class EncoderConfig(hyperparams.OneOfConfig):
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()


-ENCODER_CLS = {
-    "bert": networks.BertEncoder,
-    "mobilebert": networks.MobileBERTEncoder,
-    "albert": networks.AlbertEncoder,
-    "bigbird": bigbird_encoder.BigBirdEncoder,
-    "xlnet": networks.XLNetBase,
-}
-
-
 @gin.configurable
 def build_encoder(config: EncoderConfig,
                   embedding_layer: Optional[tf.keras.layers.Layer] = None,
@@ -205,13 +196,11 @@ def build_encoder(config: EncoderConfig,
   Returns:
     An encoder instance.
   """
-  encoder_type = config.type
-  encoder_cfg = config.get()
-  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
-  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
   if bypass_config:
     return encoder_cls()
-  if encoder_cls.__name__ == "EncoderScaffold":
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
         vocab_size=encoder_cfg.vocab_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -243,7 +232,7 @@ def build_encoder(config: EncoderConfig,
     return encoder_cls(**kwargs)

   if encoder_type == "mobilebert":
-    return encoder_cls(
+    return networks.MobileBERTEncoder(
         word_vocab_size=encoder_cfg.word_vocab_size,
         word_embed_size=encoder_cfg.word_embed_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -265,7 +254,7 @@ def build_encoder(config: EncoderConfig,
         input_mask_dtype=encoder_cfg.input_mask_dtype)

   if encoder_type == "albert":
-    return encoder_cls(
+    return networks.AlbertEncoder(
         vocab_size=encoder_cfg.vocab_size,
         embedding_width=encoder_cfg.embedding_width,
         hidden_size=encoder_cfg.hidden_size,
@@ -282,26 +271,55 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True)

   if encoder_type == "bigbird":
-    return encoder_cls(
+    # TODO(frederickliu): Support use_gradient_checkpointing.
+    if encoder_cfg.use_gradient_checkpointing:
+      raise ValueError("Gradient checkpointing unsupported at the moment.")
+    embedding_cfg = dict(
         vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
         hidden_size=encoder_cfg.hidden_size,
-        num_layers=encoder_cfg.num_layers,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate)
+    attention_cfg = dict(
+        num_heads=encoder_cfg.num_attention_heads,
+        key_dim=int(encoder_cfg.hidden_size //
+                    encoder_cfg.num_attention_heads),
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        max_rand_mask_length=encoder_cfg.max_position_embeddings,
+        num_rand_blocks=encoder_cfg.num_rand_blocks,
+        from_block_size=encoder_cfg.block_size,
+        to_block_size=encoder_cfg.block_size,
+    )
+    hidden_cfg = dict(
         num_attention_heads=encoder_cfg.num_attention_heads,
         intermediate_size=encoder_cfg.intermediate_size,
-        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
         dropout_rate=encoder_cfg.dropout_rate,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
-        num_rand_blocks=encoder_cfg.num_rand_blocks,
-        block_size=encoder_cfg.block_size,
-        max_position_embeddings=encoder_cfg.max_position_embeddings,
-        type_vocab_size=encoder_cfg.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        embedding_width=encoder_cfg.embedding_width,
-        use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
+        attention_cls=bigbird_attention.BigBirdAttention,
+        attention_cfg=attention_cfg)
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cls=layers.TransformerScaffold,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        mask_cls=bigbird_attention.BigBirdMasks,
+        mask_cfg=dict(block_size=encoder_cfg.block_size),
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=False,
+        dict_outputs=True,
+        layer_idx_as_attention_seed=True)
+    return networks.EncoderScaffold(**kwargs)

   if encoder_type == "xlnet":
-    return encoder_cls(
+    return networks.XLNetBase(
         vocab_size=encoder_cfg.vocab_size,
         num_layers=encoder_cfg.num_layers,
         hidden_size=encoder_cfg.hidden_size,
@@ -325,7 +343,7 @@ def build_encoder(config: EncoderConfig,
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
-  return encoder_cls(
+  return networks.BertEncoder(
       vocab_size=encoder_cfg.vocab_size,
       hidden_size=encoder_cfg.hidden_size,
       num_layers=encoder_cfg.num_layers,
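Taken together, the encoders.py changes drop the ENCODER_CLS lookup table in favor of explicit networks.* constructors, and rebuild the "bigbird" branch on top of networks.EncoderScaffold. A minimal sketch of exercising the new branch, assuming the Model Garden packages at this commit are importable and that BigBirdEncoderConfig defaults are valid:

from official.nlp.configs import encoders

# "bigbird" now yields an EncoderScaffold whose hidden layers are
# layers.TransformerScaffold blocks running bigbird_attention.BigBirdAttention,
# rather than a bigbird_encoder.BigBirdEncoder instance.
config = encoders.EncoderConfig(type="bigbird")
encoder = encoders.build_encoder(config)
print(type(encoder).__name__)  # expected: "EncoderScaffold"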
official/nlp/modeling/networks/encoder_scaffold.py (view file @ 102f267e)

@@ -14,6 +14,7 @@
 """Transformer-based text encoder network."""
 # pylint: disable=g-classes-have-attributes
+import copy
 import inspect

 from absl import logging
@@ -86,12 +87,19 @@ class EncoderScaffold(tf.keras.Model):
       `dropout_rate`: The overall dropout rate for the transformer layers.
       `attention_dropout_rate`: The dropout rate for the attention layers.
       `kernel_initializer`: The initializer for the transformer layers.
+    mask_cls: The class used to generate masks passed into hidden_cls() from
+      the inputs and a 2D mask indicating which positions can be attended to.
+      It is the caller's job to make sure the output of the mask layer can be
+      consumed by the hidden layer. A mask_cls is usually paired with a
+      hidden_cls.
+    mask_cfg: A dict of kwargs to pass to mask_cls.
     layer_norm_before_pooling: Whether to add a layer norm before the pooling
       layer. You probably want to turn this on if you set `norm_first=True` in
       transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    layer_idx_as_attention_seed: Whether to include the layer index in the
+      attention_cfg inside hidden_cfg.
   """

   def __init__(self,
@@ -104,9 +112,12 @@ class EncoderScaffold(tf.keras.Model):
                num_hidden_instances=1,
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
+               mask_cls=keras_nlp.layers.SelfAttentionMask,
+               mask_cfg=None,
                layer_norm_before_pooling=False,
                return_all_layer_outputs=False,
                dict_outputs=False,
+               layer_idx_as_attention_seed=False,
                **kwargs):
     if embedding_cls:
@@ -169,15 +180,25 @@ class EncoderScaffold(tf.keras.Model):
           tf.keras.layers.Dropout(
               rate=embedding_cfg['dropout_rate'])(embeddings))

-    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
+    mask_cfg = {} if mask_cfg is None else mask_cfg
+    if inspect.isclass(mask_cls):
+      mask_layer = mask_cls(**mask_cfg)
+    else:
+      mask_layer = mask_cls
+    attention_mask = mask_layer(embeddings, mask)

     data = embeddings
     layer_output_data = []
     hidden_layers = []
-    for _ in range(num_hidden_instances):
+    hidden_cfg = hidden_cfg if hidden_cfg else {}
+    for i in range(num_hidden_instances):
       if inspect.isclass(hidden_cls):
-        layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
+        if hidden_cfg and 'attention_cfg' in hidden_cfg and (
+            layer_idx_as_attention_seed):
+          hidden_cfg = copy.deepcopy(hidden_cfg)
+          hidden_cfg['attention_cfg']['seed'] = i
+        layer = hidden_cls(**hidden_cfg)
       else:
         layer = hidden_cls
       data = layer([data, attention_mask])
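The per-layer seeding above deep-copies hidden_cfg before injecting the seed, so the caller's dict is never mutated while each transformer instance receives its own attention seed. A plain-Python illustration of the same pattern (hypothetical values, not from the commit):

import copy

hidden_cfg = {'attention_cfg': {'num_heads': 4}}
per_layer_cfgs = []
for i in range(3):
  cfg = copy.deepcopy(hidden_cfg)   # leave the shared dict untouched
  cfg['attention_cfg']['seed'] = i  # the layer index becomes the seed
  per_layer_cfgs.append(cfg)

assert [c['attention_cfg']['seed'] for c in per_layer_cfgs] == [0, 1, 2]
assert 'seed' not in hidden_cfg['attention_cfg']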
@@ -227,6 +248,8 @@ class EncoderScaffold(tf.keras.Model):
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
+    self._mask_cls = mask_cls
+    self._mask_cfg = mask_cfg
     self._num_hidden_instances = num_hidden_instances
     self._pooled_output_dim = pooled_output_dim
     self._pooler_layer_initializer = pooler_layer_initializer
@@ -247,6 +270,7 @@ class EncoderScaffold(tf.keras.Model):
     if self._layer_norm_before_pooling:
       self._output_layer_norm = output_layer_norm
     self._pooler_layer = pooler_layer
+    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed

     logging.info('EncoderScaffold configs: %s', self.get_config())
@@ -260,32 +284,48 @@ class EncoderScaffold(tf.keras.Model):
         'layer_norm_before_pooling': self._layer_norm_before_pooling,
         'return_all_layer_outputs': self._return_all_layer_outputs,
         'dict_outputs': self._dict_outputs,
+        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed
     }
-    if self._hidden_cfg:
-      config_dict['hidden_cfg'] = {}
-      for k, v in self._hidden_cfg.items():
-        # `self._hidden_cfg` may contain a `class`, e.g., when `hidden_cls` is
-        # `TransformerScaffold`, its `attention_cls` argument can be a `class`.
-        if inspect.isclass(v):
-          config_dict['hidden_cfg'][k] = tf.keras.utils.get_registered_name(v)
-        else:
-          config_dict['hidden_cfg'][k] = v
-    if inspect.isclass(self._hidden_cls):
-      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
-          self._hidden_cls)
-    else:
-      config_dict['hidden_cls'] = self._hidden_cls
+    cfgs = {
+        'hidden_cfg': self._hidden_cfg,
+        'mask_cfg': self._mask_cfg
+    }
+    for cfg_name, cfg in cfgs.items():
+      if cfg:
+        config_dict[cfg_name] = {}
+        for k, v in cfg.items():
+          # The cfg may contain a `class`, e.g., when `hidden_cls` is
+          # `TransformerScaffold`, its `attention_cls` argument can be a
+          # `class`.
+          if inspect.isclass(v):
+            config_dict[cfg_name][k] = tf.keras.utils.get_registered_name(v)
+          else:
+            config_dict[cfg_name][k] = v
+
+    clss = {'hidden_cls': self._hidden_cls, 'mask_cls': self._mask_cls}
+    for cls_name, cls in clss.items():
+      if inspect.isclass(cls):
+        key = '{}_string'.format(cls_name)
+        config_dict[key] = tf.keras.utils.get_registered_name(cls)
+      else:
+        config_dict[cls_name] = cls

     config_dict.update(self._kwargs)
     return config_dict

   @classmethod
   def from_config(cls, config, custom_objects=None):
-    if 'hidden_cls_string' in config:
-      config['hidden_cls'] = tf.keras.utils.get_registered_object(
-          config['hidden_cls_string'], custom_objects=custom_objects)
-      del config['hidden_cls_string']
+    cls_names = ['hidden_cls', 'mask_cls']
+    for cls_name in cls_names:
+      cls_string = '{}_string'.format(cls_name)
+      if cls_string in config:
+        config[cls_name] = tf.keras.utils.get_registered_object(
+            config[cls_string], custom_objects=custom_objects)
+        del config[cls_string]
     return cls(**config)

   def get_embedding_table(self):
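A construction sketch for the new pluggable masking path, mirroring the kwargs that build_encoder assembles for BigBird above; the concrete sizes here are illustrative assumptions, not values from the commit:

import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling.networks import encoder_scaffold
from official.nlp.projects.bigbird import attention as bigbird_attention

init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
network = encoder_scaffold.EncoderScaffold(
    num_hidden_instances=2,
    pooled_output_dim=64,
    pooler_layer_initializer=init,
    embedding_cfg=dict(
        vocab_size=100,
        type_vocab_size=2,
        hidden_size=64,
        max_seq_length=128,
        initializer=init,
        dropout_rate=0.1),
    hidden_cls=layers.TransformerScaffold,
    hidden_cfg=dict(
        num_attention_heads=4,
        intermediate_size=256,
        intermediate_activation='gelu',
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        attention_cls=bigbird_attention.BigBirdAttention,
        attention_cfg=dict(
            num_heads=4,
            key_dim=16,
            kernel_initializer=init,
            max_rand_mask_length=128,
            num_rand_blocks=3,
            from_block_size=16,
            to_block_size=16)),
    # mask_cls replaces the previously hard-coded
    # keras_nlp.layers.SelfAttentionMask; its output must be consumable by
    # hidden_cls.
    mask_cls=bigbird_attention.BigBirdMasks,
    mask_cfg=dict(block_size=16),
    layer_idx_as_attention_seed=True,
    dict_outputs=True)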
official/nlp/modeling/networks/encoder_scaffold_test.py (view file @ 102f267e)

@@ -20,6 +20,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

 from official.modeling import activations
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.modeling.networks import encoder_scaffold
@@ -47,6 +48,30 @@ class ValidatedTransformerLayer(layers.Transformer):
     return config


+# Test class that wraps a standard self-attention mask layer. If this layer
+# is called at any point, the list passed to the config object will be filled
+# with a boolean 'True'. We register this class as a Keras serializable so we
+# can test serialization below.
+@tf.keras.utils.register_keras_serializable(package="TestOnly")
+class ValidatedMaskLayer(keras_nlp.layers.SelfAttentionMask):
+
+  def __init__(self, call_list, call_class=None, **kwargs):
+    super(ValidatedMaskLayer, self).__init__(**kwargs)
+    self.list = call_list
+    self.call_class = call_class
+
+  def call(self, inputs, mask):
+    self.list.append(True)
+    return super(ValidatedMaskLayer, self).call(inputs, mask)
+
+  def get_config(self):
+    config = super(ValidatedMaskLayer, self).get_config()
+    config["call_list"] = self.list
+    config["call_class"] = tf.keras.utils.get_registered_name(self.call_class)
+    return config
+
+
 @tf.keras.utils.register_keras_serializable(package="TestLayerOnly")
 class TestLayer(tf.keras.layers.Layer):
   pass
@@ -95,6 +120,11 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         "call_list":
             call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list":
+            mask_call_list
+    }
     # Create a small EncoderScaffold for testing.
     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=num_hidden_instances,
@@ -103,6 +133,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             stddev=0.02),
         hidden_cls=ValidatedTransformerLayer,
         hidden_cfg=hidden_cfg,
+        mask_cls=ValidatedMaskLayer,
+        mask_cfg=mask_cfg,
         embedding_cfg=embedding_cfg,
         layer_norm_before_pooling=True,
         return_all_layer_outputs=return_all_layer_outputs)
@@ -530,10 +562,15 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_list":
             call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list":
+            mask_call_list
+    }
     # Create a small EncoderScaffold for testing. This time, we pass an
     # already-instantiated layer object.
     xformer = ValidatedTransformerLayer(**hidden_cfg)
+    xmask = ValidatedMaskLayer(**mask_cfg)

     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=3,
@@ -541,6 +578,7 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cls=xformer,
+        mask_cls=xmask,
         embedding_cfg=embedding_cfg)

     # Create the inputs (note that the first dimension is implicit).
@@ -603,6 +641,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_class":
             TestLayer
     }
+    mask_call_list = []
+    mask_cfg = {"call_list": mask_call_list, "call_class": TestLayer}
     # Create a small EncoderScaffold for testing. This time, we pass an
     # already-instantiated layer object.
     kwargs = dict(
@@ -614,11 +654,16 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
     if use_hidden_cls_instance:
       xformer = ValidatedTransformerLayer(**hidden_cfg)
+      xmask = ValidatedMaskLayer(**mask_cfg)
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=xformer, **kwargs)
+          hidden_cls=xformer, mask_cls=xmask, **kwargs)
     else:
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=ValidatedTransformerLayer, hidden_cfg=hidden_cfg, **kwargs)
+          hidden_cls=ValidatedTransformerLayer,
+          hidden_cfg=hidden_cfg,
+          mask_cls=ValidatedMaskLayer,
+          mask_cfg=mask_cfg,
+          **kwargs)

     # Create another network object from the first object's config.
     new_network = encoder_scaffold.EncoderScaffold.from_config(
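A serialization round-trip sketch in the spirit of these tests, assuming a `network` built with class-valued hidden_cls/mask_cls that are registered Keras serializables (as ValidatedTransformerLayer and ValidatedMaskLayer are):

config = network.get_config()
# Classes are stored under '<name>_string' via get_registered_name()...
assert 'hidden_cls_string' in config
assert 'mask_cls_string' in config
# ...and resolved back by from_config() via get_registered_object().
restored = encoder_scaffold.EncoderScaffold.from_config(config)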
official/nlp/projects/bigbird/attention.py (view file @ 102f267e)

@@ -375,14 +375,15 @@ class BigBirdMasks(tf.keras.layers.Layer):
     super().__init__(**kwargs)
     self._block_size = block_size

-  def call(self, inputs):
-    encoder_shape = tf.shape(inputs)
+  def call(self, inputs, mask):
+    encoder_shape = tf.shape(mask)
+    mask = tf.cast(mask, inputs.dtype)
     batch_size, seq_length = encoder_shape[0], encoder_shape[1]
     # reshape for blocking
     blocked_encoder_mask = tf.reshape(
-        inputs, (batch_size, seq_length // self._block_size, self._block_size))
-    encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
-    encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
+        mask, (batch_size, seq_length // self._block_size, self._block_size))
+    encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1))
+    encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length))
     band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
                                              blocked_encoder_mask)
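BigBirdMasks.call now takes the sequence embeddings and the raw padding mask separately, reading the shape from the mask and casting it to the embeddings' dtype internally. This matches the `(inputs, mask)` calling convention of keras_nlp.layers.SelfAttentionMask, which is what lets BigBirdMasks plug into EncoderScaffold's new mask_cls hook. A usage sketch with assumed shapes:

import tensorflow as tf
from official.nlp.projects.bigbird import attention

batch_size, seq_length, width, block_size = 2, 512, 16, 64
mask_layer = attention.BigBirdMasks(block_size=block_size)
embeddings = tf.random.normal((batch_size, seq_length, width))
padding_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
# The integer mask is cast to embeddings.dtype inside the layer.
masks = mask_layer(embeddings, padding_mask)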
official/nlp/projects/bigbird/attention_test.py (view file @ 102f267e)

@@ -29,7 +29,6 @@ class BigbirdAttentionTest(tf.test.TestCase):
     block_size = 64
     mask_layer = attention.BigBirdMasks(block_size=block_size)
     encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
-    masks = mask_layer(tf.cast(encoder_inputs_mask, dtype=tf.float64))
     test_layer = attention.BigBirdAttention(
         num_heads=num_heads,
         key_dim=key_dim,
@@ -38,6 +37,7 @@ class BigbirdAttentionTest(tf.test.TestCase):
         seed=0)
     query = tf.random.normal(
         shape=(batch_size, seq_length, key_dim))
+    masks = mask_layer(query, tf.cast(encoder_inputs_mask, dtype=tf.float64))
     value = query
     output = test_layer(
         query=query,
official/nlp/projects/bigbird/encoder.py (view file @ 102f267e)

@@ -177,8 +177,7 @@ class BigBirdEncoder(tf.keras.Model):
     self._transformer_layers = []
     data = embeddings
-    masks = attention.BigBirdMasks(block_size=block_size)(
-        tf.cast(mask, embeddings.dtype))
+    masks = attention.BigBirdMasks(block_size=block_size)(data, mask)
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
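With the dtype cast moved inside BigBirdMasks, BigBirdEncoder simply forwards its raw mask input; its public interface is unchanged by this commit. A usage sketch, where the constructor argument names are taken from the old build_encoder branch above but the sizes and the [word_ids, mask, type_ids] input ordering are assumptions:

import tensorflow as tf
from official.nlp.projects.bigbird import encoder as bigbird_encoder

net = bigbird_encoder.BigBirdEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=1,
    num_attention_heads=4,
    intermediate_size=64,
    block_size=16,
    max_position_embeddings=128)
word_ids = tf.ones((1, 128), dtype=tf.int32)
mask = tf.ones((1, 128), dtype=tf.int32)     # passed through uncast now
type_ids = tf.zeros((1, 128), dtype=tf.int32)
outputs = net([word_ids, mask, type_ids])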