ModelZoo / ResNet50_tensorflow

Commit 842eb979
Authored May 19, 2021 by Frederick Liu
Committed by A. Unique TensorFlower, May 19, 2021
Parent: c035325f

Internal change

PiperOrigin-RevId: 374640999
Showing 3 changed files with 105 additions and 3 deletions (+105 −3)
official/nlp/configs/encoders.py (+68 −0)
official/nlp/modeling/layers/__init__.py (+1 −0)
official/nlp/modeling/layers/kernel_attention.py (+36 −3)
official/nlp/configs/encoders.py
@@ -140,6 +140,28 @@ class BigBirdEncoderConfig(hyperparams.Config):
   use_gradient_checkpointing: bool = False


+@dataclasses.dataclass
+class KernelEncoderConfig(hyperparams.Config):
+  """Linear encoder configuration."""
+  vocab_size: int = 30522
+  hidden_size: int = 768
+  num_layers: int = 12
+  num_attention_heads: int = 12
+  hidden_activation: str = "gelu"
+  intermediate_size: int = 3072
+  dropout_rate: float = 0.1
+  attention_dropout_rate: float = 0.1
+  max_position_embeddings: int = 512
+  type_vocab_size: int = 2
+  initializer_range: float = 0.02
+  embedding_size: Optional[int] = None
+  feature_transform: str = "exp"
+  num_random_features: int = 256
+  redraw: bool = False
+  is_short_seq: bool = False
+  begin_kernel: int = 0
+
+
 @dataclasses.dataclass
 class XLNetEncoderConfig(hyperparams.Config):
   """XLNet encoder configuration."""
@@ -172,6 +194,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   albert: AlbertEncoderConfig = AlbertEncoderConfig()
   bert: BertEncoderConfig = BertEncoderConfig()
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
+  kernel: KernelEncoderConfig = KernelEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
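Registering `kernel` on the OneOfConfig makes the new encoder selectable by name alongside bert, bigbird, and the rest. A hedged sketch of selecting it, assuming the usual hyperparams.OneOfConfig behaviour where the `type` field names the active sub-config:

from official.nlp.configs import encoders

# Assumed OneOfConfig usage: `type` picks which sub-config is active.
config = encoders.EncoderConfig(
    type="kernel",
    kernel=encoders.KernelEncoderConfig(hidden_size=512, num_attention_heads=8))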
@@ -317,6 +340,51 @@ def build_encoder(config: EncoderConfig,
         layer_idx_as_attention_seed=True)
     return networks.EncoderScaffold(**kwargs)

+  if encoder_type == "kernel":
+    embedding_cfg = dict(
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate)
+    attention_cfg = dict(
+        num_heads=encoder_cfg.num_attention_heads,
+        key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        feature_transform=encoder_cfg.feature_transform,
+        num_random_features=encoder_cfg.num_random_features,
+        redraw=encoder_cfg.redraw,
+        is_short_seq=encoder_cfg.is_short_seq,
+        begin_kernel=encoder_cfg.begin_kernel,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        attention_cls=layers.KernelAttention,
+        attention_cfg=attention_cfg)
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cls=layers.TransformerScaffold,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        mask_cls=layers.KernelMask,
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=False,
+        dict_outputs=True,
+        layer_idx_as_attention_seed=True)
+    return networks.EncoderScaffold(**kwargs)
+
   if encoder_type == "xlnet":
     return networks.XLNetBase(
         vocab_size=encoder_cfg.vocab_size,
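Put together, the new branch wires layers.KernelAttention and layers.KernelMask into a networks.EncoderScaffold built from layers.TransformerScaffold blocks. An illustrative end-to-end sketch, not from the commit, assuming config.type drives encoder_type the same way it does for the other encoders:

from official.nlp.configs import encoders

# Hypothetical call: build a kernel encoder with default hyperparameters.
encoder_config = encoders.EncoderConfig(type="kernel")
encoder = encoders.build_encoder(encoder_config)
# `encoder` is a networks.EncoderScaffold whose transformer blocks use
# layers.KernelAttention and whose input mask comes from layers.KernelMask,
# per the kwargs assembled in the branch above.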
official/nlp/modeling/layers/__init__.py
@@ -25,6 +25,7 @@ from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
 from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
 from official.nlp.modeling.layers.kernel_attention import KernelAttention
+from official.nlp.modeling.layers.kernel_attention import KernelMask
 from official.nlp.modeling.layers.masked_lm import MaskedLM
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.mat_mul_with_margin import MatMulWithMargin
official/nlp/modeling/layers/kernel_attention.py
@@ -21,6 +21,24 @@ import tensorflow as tf
 _NUMERIC_STABLER = 1e-6


+class KernelMask(tf.keras.layers.Layer):
+  """Creates kernel attention mask.
+
+  inputs: from_tensor: 2D or 3D Tensor of shape
+    [batch_size, from_seq_length, ...].
+  mask: a Tensor of shape [batch_size, from_seq_length] which indicates
+    which part of the inputs we should not attend.
+
+  Returns:
+    float Tensor of shape [batch_size, from_seq_length] that KernelAttention
+    takes as mask.
+  """
+
+  def call(self, inputs, mask):
+    mask = tf.cast(mask, inputs.dtype)
+    return mask
+
+
 def create_projection_matrix(m, d, seed=None):
   r"""Constructs the matrix of random projections.
@@ -248,7 +266,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       short or long sequences; usually short sequence is defined as having
       length L <= 1024.
     attention_mask: a boolean mask of shape `[B, S]`, that prevents
-      attention to certain positions. Note that the mask is only appied to
+      attenting to masked positions. Note that the mask is only appied to
       the keys. User may want to mask the output if query contains pads.
     training: Python boolean indicating whether the layer should behave in
       training mode (adding dropout) or in inference mode (doing nothing).
@@ -305,8 +323,23 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
            value,
            key=None,
            attention_mask=None,
-           training=False):
+           training=False,
+           **kwargs):
+    """Compute attention with kernel mechanism.
+
+    Args:
+      query: Query `Tensor` of shape `[B, T, dim]`.
+      value: Value `Tensor` of shape `[B, S, dim]`.
+      key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
+        `value` for both `key` and `value`, which is the most common case.
+      attention_mask: a boolean mask of shape `[B, S]`, that prevents
+        attenting to masked positions. Note that the mask is only appied to
+        the keys. User may want to mask the output if query contains pads.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).
+
+    Returns:
+      Multi-headed outputs of attention computation.
+    """
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
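The call signature gains **kwargs (so extra keyword arguments from callers are tolerated) plus a full docstring. A hedged usage sketch; the constructor arguments mirror the attention_cfg assembled in encoders.py above, and the concrete values are illustrative only:

import tensorflow as tf
from official.nlp.modeling.layers.kernel_attention import KernelAttention

# Constructor arguments taken from the attention_cfg dict in the diff above;
# the numeric values here are placeholders.
attention = KernelAttention(
    num_heads=4,
    key_dim=16,
    feature_transform="exp",
    num_random_features=64,
    redraw=False,
    is_short_seq=False,
    begin_kernel=0)

query = tf.random.uniform([2, 10, 64])   # [B, T, dim]
value = tf.random.uniform([2, 12, 64])   # [B, S, dim]
mask = tf.ones([2, 12])                  # mask over the keys; 1 = attend
outputs = attention(query, value, attention_mask=mask, training=False)
# `outputs` holds the multi-headed attention results over the query positions.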