Commit 102f267e

Authored May 17, 2021 by Frederick Liu; committed by A. Unique TensorFlower, May 17, 2021

Internal change

PiperOrigin-RevId: 374236491
parent 34731381

Showing 6 changed files with 165 additions and 62 deletions (+165 −62)
official/nlp/configs/encoders.py                         +48 −30
official/nlp/modeling/networks/encoder_scaffold.py       +62 −22
official/nlp/modeling/networks/encoder_scaffold_test.py  +47 −2
official/nlp/projects/bigbird/attention.py               +6 −5
official/nlp/projects/bigbird/attention_test.py          +1 −1
official/nlp/projects/bigbird/encoder.py                 +1 −2
official/nlp/configs/encoders.py

```diff
@@ -18,15 +18,15 @@ Includes configurations and factory methods.
 """
 from typing import Optional

-from absl import logging
 import dataclasses
 import gin
 import tensorflow as tf

 from official.modeling import hyperparams
 from official.modeling import tf_utils
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
-from official.nlp.projects.bigbird import encoder as bigbird_encoder
+from official.nlp.projects.bigbird import attention as bigbird_attention


 @dataclasses.dataclass
@@ -177,15 +177,6 @@ class EncoderConfig(hyperparams.OneOfConfig):
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()


-ENCODER_CLS = {
-    "bert": networks.BertEncoder,
-    "mobilebert": networks.MobileBERTEncoder,
-    "albert": networks.AlbertEncoder,
-    "bigbird": bigbird_encoder.BigBirdEncoder,
-    "xlnet": networks.XLNetBase,
-}
-
-
 @gin.configurable
 def build_encoder(config: EncoderConfig,
                   embedding_layer: Optional[tf.keras.layers.Layer] = None,
@@ -205,13 +196,11 @@ def build_encoder(config: EncoderConfig,
   Returns:
     An encoder instance.
   """
-  encoder_type = config.type
-  encoder_cfg = config.get()
-  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
-  logging.info("Encoder class: %s to build...", encoder_cls.__name__)
   if bypass_config:
     return encoder_cls()
-  if encoder_cls.__name__ == "EncoderScaffold":
+  encoder_type = config.type
+  encoder_cfg = config.get()
+  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
         vocab_size=encoder_cfg.vocab_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -243,7 +232,7 @@ def build_encoder(config: EncoderConfig,
     return encoder_cls(**kwargs)

   if encoder_type == "mobilebert":
-    return encoder_cls(
+    return networks.MobileBERTEncoder(
         word_vocab_size=encoder_cfg.word_vocab_size,
         word_embed_size=encoder_cfg.word_embed_size,
         type_vocab_size=encoder_cfg.type_vocab_size,
@@ -265,7 +254,7 @@ def build_encoder(config: EncoderConfig,
         input_mask_dtype=encoder_cfg.input_mask_dtype)

   if encoder_type == "albert":
-    return encoder_cls(
+    return networks.AlbertEncoder(
         vocab_size=encoder_cfg.vocab_size,
         embedding_width=encoder_cfg.embedding_width,
         hidden_size=encoder_cfg.hidden_size,
@@ -282,26 +271,55 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True)

   if encoder_type == "bigbird":
-    return encoder_cls(
-        vocab_size=encoder_cfg.vocab_size,
-        hidden_size=encoder_cfg.hidden_size,
-        num_layers=encoder_cfg.num_layers,
-        num_attention_heads=encoder_cfg.num_attention_heads,
-        intermediate_size=encoder_cfg.intermediate_size,
-        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
-        dropout_rate=encoder_cfg.dropout_rate,
-        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
-        num_rand_blocks=encoder_cfg.num_rand_blocks,
-        block_size=encoder_cfg.block_size,
-        max_position_embeddings=encoder_cfg.max_position_embeddings,
-        type_vocab_size=encoder_cfg.type_vocab_size,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        embedding_width=encoder_cfg.embedding_width,
-        use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
+    # TODO(frederickliu): Support use_gradient_checkpointing.
+    if encoder_cfg.use_gradient_checkpointing:
+      raise ValueError("Gradient checkpointing unsupported at the moment.")
+    embedding_cfg = dict(
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate)
+    attention_cfg = dict(
+        num_heads=encoder_cfg.num_attention_heads,
+        key_dim=int(encoder_cfg.hidden_size //
+                    encoder_cfg.num_attention_heads),
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        max_rand_mask_length=encoder_cfg.max_position_embeddings,
+        num_rand_blocks=encoder_cfg.num_rand_blocks,
+        from_block_size=encoder_cfg.block_size,
+        to_block_size=encoder_cfg.block_size,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        attention_cls=bigbird_attention.BigBirdAttention,
+        attention_cfg=attention_cfg)
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cls=layers.TransformerScaffold,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        mask_cls=bigbird_attention.BigBirdMasks,
+        mask_cfg=dict(block_size=encoder_cfg.block_size),
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=False,
+        dict_outputs=True,
+        layer_idx_as_attention_seed=True)
+    return networks.EncoderScaffold(**kwargs)

   if encoder_type == "xlnet":
-    return encoder_cls(
+    return networks.XLNetBase(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
@@ -325,7 +343,7 @@ def build_encoder(config: EncoderConfig,
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
-  return encoder_cls(
+  return networks.BertEncoder(
       vocab_size=encoder_cfg.vocab_size,
       hidden_size=encoder_cfg.hidden_size,
       num_layers=encoder_cfg.num_layers,
```
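Taken together, the encoders.py changes drop the ENCODER_CLS lookup table in favor of explicit per-type branches, and rebuild the "bigbird" branch on top of networks.EncoderScaffold: layers.TransformerScaffold hidden layers running bigbird_attention.BigBirdAttention, with bigbird_attention.BigBirdMasks plugged in through the new mask_cls/mask_cfg hooks. A minimal usage sketch of the new path (not part of the diff; the hyperparameter values are illustrative, and the BigBirdEncoderConfig fields are assumed to follow this file's config schema):

```python
import tensorflow as tf

from official.nlp.configs import encoders

# Illustrative settings; seq_length must be divisible by block_size for
# BigBird's blocking reshape to work.
config = encoders.EncoderConfig(
    type="bigbird",
    bigbird=encoders.BigBirdEncoderConfig(num_layers=2, block_size=64))

# After this change, the "bigbird" branch returns a networks.EncoderScaffold
# rather than a bigbird_encoder.BigBirdEncoder.
encoder = encoders.build_encoder(config)

batch_size, seq_length = 2, 512
word_ids = tf.ones((batch_size, seq_length), dtype=tf.int32)
mask = tf.ones((batch_size, seq_length), dtype=tf.int32)
type_ids = tf.zeros((batch_size, seq_length), dtype=tf.int32)
outputs = encoder([word_ids, mask, type_ids])  # dict_outputs=True -> a dict
```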
official/nlp/modeling/networks/encoder_scaffold.py

```diff
@@ -14,6 +14,7 @@
 """Transformer-based text encoder network."""
 # pylint: disable=g-classes-have-attributes
+import copy
 import inspect

 from absl import logging
@@ -86,12 +87,19 @@ class EncoderScaffold(tf.keras.Model):
       `dropout_rate`: The overall dropout rate for the transformer layers.
       `attention_dropout_rate`: The dropout rate for the attention layers.
       `kernel_initializer`: The initializer for the transformer layers.
+    mask_cls: The class to generate masks passed into hidden_cls() from inputs
+      and 2D mask indicating positions we can attend to. It is the caller's
+      job to make sure the output of the mask layer can be used by the hidden
+      layer. A mask_cls is usually mapped to a hidden_cls.
+    mask_cfg: A dict of kwargs to pass to mask_cls.
     layer_norm_before_pooling: Whether to add a layer norm before the pooling
       layer. You probably want to turn this on if you set `norm_first=True` in
       transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    layer_idx_as_attention_seed: Whether to include layer_idx in
+      attention_cfg in hidden_cfg.
   """
@@ -104,9 +112,12 @@ class EncoderScaffold(tf.keras.Model):
                num_hidden_instances=1,
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
+               mask_cls=keras_nlp.layers.SelfAttentionMask,
+               mask_cfg=None,
                layer_norm_before_pooling=False,
                return_all_layer_outputs=False,
                dict_outputs=False,
+               layer_idx_as_attention_seed=False,
                **kwargs):
     if embedding_cls:
@@ -169,15 +180,25 @@ class EncoderScaffold(tf.keras.Model):
           tf.keras.layers.Dropout(
               rate=embedding_cfg['dropout_rate'])(embeddings))

-      attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
+    mask_cfg = {} if mask_cfg is None else mask_cfg
+    if inspect.isclass(mask_cls):
+      mask_layer = mask_cls(**mask_cfg)
+    else:
+      mask_layer = mask_cls
+    attention_mask = mask_layer(embeddings, mask)

     data = embeddings
     layer_output_data = []
     hidden_layers = []
-    for _ in range(num_hidden_instances):
-      layer = hidden_cls(**hidden_cfg) if hidden_cfg else hidden_cls()
+    hidden_cfg = hidden_cfg if hidden_cfg else {}
+    for i in range(num_hidden_instances):
+      if inspect.isclass(hidden_cls):
+        if hidden_cfg and 'attention_cfg' in hidden_cfg and (
+            layer_idx_as_attention_seed):
+          hidden_cfg = copy.deepcopy(hidden_cfg)
+          hidden_cfg['attention_cfg']['seed'] = i
+        layer = hidden_cls(**hidden_cfg)
+      else:
+        layer = hidden_cls
       data = layer([data, attention_mask])
@@ -227,6 +248,8 @@ class EncoderScaffold(tf.keras.Model):
     self._hidden_cls = hidden_cls
     self._hidden_cfg = hidden_cfg
+    self._mask_cls = mask_cls
+    self._mask_cfg = mask_cfg
     self._num_hidden_instances = num_hidden_instances
     self._pooled_output_dim = pooled_output_dim
     self._pooler_layer_initializer = pooler_layer_initializer
@@ -247,6 +270,7 @@ class EncoderScaffold(tf.keras.Model):
     if self._layer_norm_before_pooling:
       self._output_layer_norm = output_layer_norm
     self._pooler_layer = pooler_layer
+    self._layer_idx_as_attention_seed = layer_idx_as_attention_seed

     logging.info('EncoderScaffold configs: %s', self.get_config())
@@ -260,32 +284,48 @@ class EncoderScaffold(tf.keras.Model):
         'layer_norm_before_pooling': self._layer_norm_before_pooling,
         'return_all_layer_outputs': self._return_all_layer_outputs,
         'dict_outputs': self._dict_outputs,
+        'layer_idx_as_attention_seed': self._layer_idx_as_attention_seed
     }

-    if self._hidden_cfg:
-      config_dict['hidden_cfg'] = {}
-      for k, v in self._hidden_cfg.items():
-        # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
-        # `TransformerScaffold`, its `attention_cls` argument can be a `class`.
-        if inspect.isclass(v):
-          config_dict['hidden_cfg'][k] = tf.keras.utils.get_registered_name(v)
-        else:
-          config_dict['hidden_cfg'][k] = v
+    cfgs = {
+        'hidden_cfg': self._hidden_cfg,
+        'mask_cfg': self._mask_cfg
+    }
+    for cfg_name, cfg in cfgs.items():
+      if cfg:
+        config_dict[cfg_name] = {}
+        for k, v in cfg.items():
+          # `self._hidden_cfg` may contain `class`, e.g., when `hidden_cfg` is
+          # `TransformerScaffold`, `attention_cls` argument can be a `class`.
+          if inspect.isclass(v):
+            config_dict[cfg_name][k] = tf.keras.utils.get_registered_name(v)
+          else:
+            config_dict[cfg_name][k] = v

-    if inspect.isclass(self._hidden_cls):
-      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
-          self._hidden_cls)
-    else:
-      config_dict['hidden_cls'] = self._hidden_cls
+    clss = {
+        'hidden_cls': self._hidden_cls,
+        'mask_cls': self._mask_cls
+    }
+    for cls_name, cls in clss.items():
+      if inspect.isclass(cls):
+        key = '{}_string'.format(cls_name)
+        config_dict[key] = tf.keras.utils.get_registered_name(cls)
+      else:
+        config_dict[cls_name] = cls

     config_dict.update(self._kwargs)
     return config_dict

   @classmethod
   def from_config(cls, config, custom_objects=None):
-    if 'hidden_cls_string' in config:
-      config['hidden_cls'] = tf.keras.utils.get_registered_object(
-          config['hidden_cls_string'], custom_objects=custom_objects)
-      del config['hidden_cls_string']
+    cls_names = ['hidden_cls', 'mask_cls']
+    for cls_name in cls_names:
+      cls_string = '{}_string'.format(cls_name)
+      if cls_string in config:
+        config[cls_name] = tf.keras.utils.get_registered_object(
+            config[cls_string], custom_objects=custom_objects)
+        del config[cls_string]
     return cls(**config)

   def get_embedding_table(self):
```
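The net effect in encoder_scaffold.py: the attention-mask layer, previously hard-coded to keras_nlp.layers.SelfAttentionMask, is now injectable (as a class plus kwargs, or as an already-built layer), and hidden layers can receive a per-layer attention seed. A minimal construction sketch using the new arguments (not from the diff; the sizes are illustrative, and the defaults shown preserve the old behavior):

```python
import tensorflow as tf

from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.networks import encoder_scaffold

# Illustrative embedding/hidden settings for a tiny scaffold.
embedding_cfg = dict(
    vocab_size=100,
    type_vocab_size=2,
    hidden_size=32,
    seq_length=16,
    max_seq_length=16,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    dropout_rate=0.1)
hidden_cfg = dict(
    num_attention_heads=4,
    intermediate_size=64,
    intermediate_activation='relu',
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))

network = encoder_scaffold.EncoderScaffold(
    num_hidden_instances=2,
    pooled_output_dim=32,
    pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
        stddev=0.02),
    hidden_cls=layers.Transformer,
    hidden_cfg=hidden_cfg,
    # New hooks added by this commit; these values match the defaults.
    mask_cls=keras_nlp.layers.SelfAttentionMask,
    mask_cfg=None,
    embedding_cfg=embedding_cfg)
```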
official/nlp/modeling/networks/encoder_scaffold_test.py

```diff
@@ -20,6 +20,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

 from official.modeling import activations
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.modeling.networks import encoder_scaffold
@@ -47,6 +48,30 @@ class ValidatedTransformerLayer(layers.Transformer):
     return config


+# Test class that wraps a standard self attention mask layer.
+# If this layer is called at any point, the list passed to the config
+# object will be filled with a boolean 'True'. We register this class as a
+# Keras serializable so we can test serialization below.
+@tf.keras.utils.register_keras_serializable(package="TestOnly")
+class ValidatedMaskLayer(keras_nlp.layers.SelfAttentionMask):
+
+  def __init__(self, call_list, call_class=None, **kwargs):
+    super(ValidatedMaskLayer, self).__init__(**kwargs)
+    self.list = call_list
+    self.call_class = call_class
+
+  def call(self, inputs, mask):
+    self.list.append(True)
+    return super(ValidatedMaskLayer, self).call(inputs, mask)
+
+  def get_config(self):
+    config = super(ValidatedMaskLayer, self).get_config()
+    config["call_list"] = self.list
+    config["call_class"] = tf.keras.utils.get_registered_name(self.call_class)
+    return config
+
+
 @tf.keras.utils.register_keras_serializable(package="TestLayerOnly")
 class TestLayer(tf.keras.layers.Layer):
   pass
@@ -95,6 +120,11 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         "call_list": call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list": mask_call_list
+    }
+
     # Create a small EncoderScaffold for testing.
     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=num_hidden_instances,
@@ -103,6 +133,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             stddev=0.02),
         hidden_cls=ValidatedTransformerLayer,
         hidden_cfg=hidden_cfg,
+        mask_cls=ValidatedMaskLayer,
+        mask_cfg=mask_cfg,
         embedding_cfg=embedding_cfg,
         layer_norm_before_pooling=True,
         return_all_layer_outputs=return_all_layer_outputs)
@@ -530,10 +562,15 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_list": call_list
     }
+    mask_call_list = []
+    mask_cfg = {
+        "call_list": mask_call_list
+    }
     # Create a small EncoderScaffold for testing. This time, we pass an
     # already-instantiated layer object.
     xformer = ValidatedTransformerLayer(**hidden_cfg)
+    xmask = ValidatedMaskLayer(**mask_cfg)

     test_network = encoder_scaffold.EncoderScaffold(
         num_hidden_instances=3,
@@ -541,6 +578,7 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cls=xformer,
+        mask_cls=xmask,
         embedding_cfg=embedding_cfg)

     # Create the inputs (note that the first dimension is implicit).
@@ -603,6 +641,8 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
         "call_class": TestLayer
     }
+    mask_call_list = []
+    mask_cfg = {"call_list": mask_call_list, "call_class": TestLayer}
+
     # Create a small EncoderScaffold for testing. This time, we pass an
     # already-instantiated layer object.
     kwargs = dict(
@@ -614,11 +654,16 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
     if use_hidden_cls_instance:
       xformer = ValidatedTransformerLayer(**hidden_cfg)
+      xmask = ValidatedMaskLayer(**mask_cfg)
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=xformer, **kwargs)
+          hidden_cls=xformer, mask_cls=xmask, **kwargs)
     else:
       test_network = encoder_scaffold.EncoderScaffold(
-          hidden_cls=ValidatedTransformerLayer, hidden_cfg=hidden_cfg, **kwargs)
+          hidden_cls=ValidatedTransformerLayer,
+          hidden_cfg=hidden_cfg,
+          mask_cls=ValidatedMaskLayer,
+          mask_cfg=mask_cfg,
+          **kwargs)

     # Create another network object from the first object's config.
     new_network = encoder_scaffold.EncoderScaffold.from_config(
```
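The new ValidatedMaskLayer mirrors the existing ValidatedTransformerLayer so the tests can verify that mask_cls participates in get_config()/from_config() round-trips, both as a class and as an instantiated layer. A standalone sketch of the '{}_string' pattern those round-trips rely on (a paraphrase for illustration, not code from this commit):

```python
import inspect

import tensorflow as tf


def classes_to_config(config_dict, clss):
  # Mirrors EncoderScaffold.get_config: registered classes are stored under
  # '<name>_string'; already-instantiated layer objects are stored directly.
  for cls_name, cls in clss.items():
    if inspect.isclass(cls):
      config_dict['{}_string'.format(cls_name)] = (
          tf.keras.utils.get_registered_name(cls))
    else:
      config_dict[cls_name] = cls
  return config_dict


def classes_from_config(config, custom_objects=None):
  # Mirrors EncoderScaffold.from_config: '<name>_string' entries are resolved
  # back into classes before the constructor is called.
  for cls_name in ('hidden_cls', 'mask_cls'):
    cls_string = '{}_string'.format(cls_name)
    if cls_string in config:
      config[cls_name] = tf.keras.utils.get_registered_object(
          config[cls_string], custom_objects=custom_objects)
      del config[cls_string]
  return config
```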
official/nlp/projects/bigbird/attention.py

```diff
@@ -375,14 +375,15 @@ class BigBirdMasks(tf.keras.layers.Layer):
     super().__init__(**kwargs)
     self._block_size = block_size

-  def call(self, inputs):
-    encoder_shape = tf.shape(inputs)
+  def call(self, inputs, mask):
+    encoder_shape = tf.shape(mask)
+    mask = tf.cast(mask, inputs.dtype)
     batch_size, seq_length = encoder_shape[0], encoder_shape[1]
     # reshape for blocking
     blocked_encoder_mask = tf.reshape(
-        inputs, (batch_size, seq_length // self._block_size, self._block_size))
-    encoder_from_mask = tf.reshape(inputs, (batch_size, 1, seq_length, 1))
-    encoder_to_mask = tf.reshape(inputs, (batch_size, 1, 1, seq_length))
+        mask, (batch_size, seq_length // self._block_size, self._block_size))
+    encoder_from_mask = tf.reshape(mask, (batch_size, 1, seq_length, 1))
+    encoder_to_mask = tf.reshape(mask, (batch_size, 1, 1, seq_length))

     band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
                                              blocked_encoder_mask)
```
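BigBirdMasks thus moves from taking a single pre-cast mask tensor to the (inputs, mask) calling convention that EncoderScaffold's mask_cls hook expects, casting the 2D padding mask to the embeddings' dtype internally before building the band/from/to masks that BigBirdAttention consumes. A shape sketch of the new signature (illustrative values, not from the diff):

```python
import tensorflow as tf

from official.nlp.projects.bigbird import attention

batch_size, seq_length, hidden_size, block_size = 2, 128, 16, 32
# seq_length must be divisible by block_size for the blocking reshape.
embeddings = tf.random.normal((batch_size, seq_length, hidden_size))
padding_mask = tf.ones((batch_size, seq_length), dtype=tf.int32)

mask_layer = attention.BigBirdMasks(block_size=block_size)
# New convention: (inputs, mask); the mask is cast to embeddings.dtype
# inside the layer, so an integer padding mask is fine here.
masks = mask_layer(embeddings, padding_mask)
```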
official/nlp/projects/bigbird/attention_test.py

```diff
@@ -29,7 +29,6 @@ class BigbirdAttentionTest(tf.test.TestCase):
     block_size = 64
     mask_layer = attention.BigBirdMasks(block_size=block_size)
     encoder_inputs_mask = tf.zeros((batch_size, seq_length), dtype=tf.int32)
-    masks = mask_layer(tf.cast(encoder_inputs_mask, dtype=tf.float64))
     test_layer = attention.BigBirdAttention(
         num_heads=num_heads,
         key_dim=key_dim,
@@ -38,6 +37,7 @@ class BigbirdAttentionTest(tf.test.TestCase):
         seed=0)
     query = tf.random.normal(shape=(batch_size, seq_length, key_dim))
+    masks = mask_layer(query, tf.cast(encoder_inputs_mask, dtype=tf.float64))
     value = query
     output = test_layer(
         query=query,
```
official/nlp/projects/bigbird/encoder.py

```diff
@@ -177,8 +177,7 @@ class BigBirdEncoder(tf.keras.Model):
     self._transformer_layers = []
     data = embeddings
-    masks = attention.BigBirdMasks(block_size=block_size)(
-        tf.cast(mask, embeddings.dtype))
+    masks = attention.BigBirdMasks(block_size=block_size)(data, mask)
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
```