Commit c02bca43
Authored Jan 28, 2021 by Hongkun Yu; committed by A. Unique TensorFlower, Jan 28, 2021
Parent: 2fba0107

Adds T5/MTF style relative position bias layer.

PiperOrigin-RevId: 354401143
Changes: 3 changed files, with 169 additions and 5 deletions (+169 −5)

  official/nlp/modeling/layers/__init__.py                  +1    −0
  official/nlp/modeling/layers/position_embedding.py        +139  −5
  official/nlp/modeling/layers/position_embedding_test.py   +29   −0
official/nlp/modeling/layers/__init__.py (view file @ c02bca43)

@@ -26,6 +26,7 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
 from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
 from official.nlp.modeling.layers.multi_channel_attention import *
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
+from official.nlp.modeling.layers.position_embedding import RelativePositionBias
 from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
 from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
 from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
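The new export makes the layer reachable through the public `layers` namespace. A minimal usage sketch, assuming the standard TF Model Garden import path at this commit (the tensor shapes below are arbitrary; only the constructor and call signature come from this diff):

    import tensorflow as tf
    from official.nlp.modeling import layers

    # Build the bias layer for a 3-head attention block.
    rel_bias = layers.RelativePositionBias(num_heads=3)

    query = tf.zeros((2, 8, 16))  # [batch, query_length, hidden_size]
    key = tf.zeros((2, 8, 16))    # [batch, key_length, hidden_size]

    bias = rel_bias(query, key)   # [batch, num_heads, query_length, key_length]
    print(bias.shape)             # (2, 3, 8, 8)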
official/nlp/modeling/layers/position_embedding.py (view file @ c02bca43)

@@ -14,13 +14,15 @@
 # ==============================================================================
 """Keras-based positional embedding layer."""
 # pylint: disable=g-classes-have-attributes
 import math
+from typing import Optional

 import tensorflow as tf

 from official.modeling import tf_utils

+Initializer = tf.keras.initializers.Initializer


 @tf.keras.utils.register_keras_serializable(package="Text")
 class RelativePositionEmbedding(tf.keras.layers.Layer):

@@ -38,9 +40,9 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
   """

   def __init__(self,
-               hidden_size,
-               min_timescale=1.0,
-               max_timescale=1.0e4,
+               hidden_size: int,
+               min_timescale: float = 1.0,
+               max_timescale: float = 1.0e4,
                **kwargs):
     # We need to have a default dtype of float32, since the inputs (which Keras
     # usually uses to infer the dtype) will always be int32.

@@ -50,7 +52,7 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
     if "dtype" not in kwargs:
       kwargs["dtype"] = "float32"

-    super(RelativePositionEmbedding, self).__init__(**kwargs)
+    super().__init__(**kwargs)
     self._hidden_size = hidden_size
     self._min_timescale = min_timescale
     self._max_timescale = max_timescale

@@ -101,3 +103,135 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
         [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)

     return position_embeddings
+
+
+def _relative_position_bucket(relative_position,
+                              bidirectional=True,
+                              num_buckets=32,
+                              max_distance=128):
+  """Translate relative position to a bucket number for relative attention.
+
+  The relative position is defined as memory_position - query_position, i.e.
+  the distance in tokens from the attending position to the attended-to
+  position.
+
+  If bidirectional=False, then positive relative positions are invalid.
+  We use smaller buckets for small absolute relative_position and larger
+  buckets for larger absolute relative_positions.
+  All relative positions >=max_distance map to the same bucket.
+  All relative positions <=-max_distance map to the same bucket.
+  This should allow for more graceful generalization to longer sequences
+  than the model has been trained on.
+
+  Args:
+    relative_position: an int32 Tensor
+    bidirectional: a boolean - whether the attention is bidirectional
+    num_buckets: an integer
+    max_distance: an integer
+
+  Returns:
+    a Tensor with the same shape as relative_position, containing int32
+    values in the range [0, num_buckets)
+  """
+  ret = 0
+  n = -relative_position
+  if bidirectional:
+    num_buckets //= 2
+    ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
+    n = tf.math.abs(n)
+  else:
+    n = tf.math.maximum(n, 0)
+  # now n is in the range [0, inf)
+  max_exact = num_buckets // 2
+  is_small = tf.math.less(n, max_exact)
+  val_if_large = max_exact + tf.dtypes.cast(
+      tf.math.log(tf.cast(n, tf.float32) / max_exact) /
+      math.log(max_distance / max_exact) * (num_buckets - max_exact),
+      tf.int32,
+  )
+  val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
+  ret += tf.where(is_small, n, val_if_large)
+  return ret
+
+
+@tf.keras.utils.register_keras_serializable(package="Text")
+class RelativePositionBias(tf.keras.layers.Layer):
+  """Relative position embedding via per-head bias in T5 style.
+
+  Reference implementation in MeshTF:
+  https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000
+
+  This layer implements the relative position bias used in "Exploring the
+  Limits of Transfer Learning with a Unified Text-to-Text Transformer"
+  (https://arxiv.org/abs/1910.10683).
+  """
+
+  def __init__(self,
+               num_heads: int,
+               relative_attention_num_buckets: int = 32,
+               relative_attention_max_distance: int = 128,
+               bidirectional: bool = True,
+               embeddings_initializer: Optional[Initializer] = None,
+               **kwargs):
+    super().__init__(**kwargs)
+    self.num_heads = num_heads
+    self.relative_attention_num_buckets = relative_attention_num_buckets
+    self.bidirectional = bidirectional
+    self.relative_attention_max_distance = relative_attention_max_distance
+    if embeddings_initializer:
+      self._embed_init = embeddings_initializer
+    else:
+      self._embed_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
+    with tf.name_scope(self.name):
+      self._relative_attention_bias = self.add_weight(
+          "rel_embedding",
+          shape=[self.relative_attention_num_buckets, self.num_heads],
+          initializer=self._embed_init,
+          dtype=self.dtype,
+          trainable=True)
+
+  def get_config(self):
+    config = {
+        "num_heads": self.num_heads,
+        "relative_attention_num_buckets": self.relative_attention_num_buckets,
+        "relative_attention_max_distance":
+            self.relative_attention_max_distance,
+        "bidirectional": self.bidirectional,
+        "embeddings_initializer":
+            tf.keras.initializers.serialize(self._embed_init),
+    }
+    base_config = super().get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  def call(self, query: tf.Tensor, key: tf.Tensor):
+    """Implements the forward pass.
+
+    Args:
+      query: query input tensor shape [batch, query length, hidden size].
+      key: key input tensor shape [batch, key length, hidden size].
+
+    Returns:
+      A tensor in shape of [batch, heads, query length, key length].
+    """
+    batch_size, qlen = tf_utils.get_shape_list(query)[:2]
+    klen = tf_utils.get_shape_list(key)[1]
+    context_position = tf.range(qlen)[:, None]
+    memory_position = tf.range(klen)[None, :]
+    relative_position = memory_position - context_position
+    rp_bucket = _relative_position_bucket(
+        relative_position,
+        bidirectional=self.bidirectional,
+        num_buckets=self.relative_attention_num_buckets,
+        max_distance=self.relative_attention_max_distance)
+    values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
+    values = tf.expand_dims(
+        tf.transpose(values, [2, 0, 1]),
+        axis=0)  # shape (1, num_heads, qlen, klen)
+    values = tf.tile(values, [batch_size, 1, 1, 1])
+    return values
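The commit only adds the bias layer itself; wiring it into attention is out of scope here. In T5 the learned per-head bias is added to the pre-softmax attention logits, so a minimal sketch under that assumption looks as follows (the einsum wiring and all shapes below are illustrative, not taken from this diff):

    import tensorflow as tf
    from official.nlp.modeling.layers import position_embedding

    batch, q_len, k_len, num_heads, head_dim = 2, 5, 7, 4, 8

    # Stand-ins for the outputs of real query/key projections,
    # shape [batch, heads, length, head_dim].
    q = tf.random.normal((batch, num_heads, q_len, head_dim))
    k = tf.random.normal((batch, num_heads, k_len, head_dim))

    # The layer only reads the sequence lengths of its inputs, so any
    # [batch, length, hidden] tensors with the right lengths work here.
    rel_bias_layer = position_embedding.RelativePositionBias(num_heads=num_heads)
    bias = rel_bias_layer(tf.zeros((batch, q_len, 1)), tf.zeros((batch, k_len, 1)))

    # T5 adds the learned per-head bias to the logits before the softmax.
    logits = tf.einsum("bhqd,bhkd->bhqk", q, k)      # [batch, heads, q_len, k_len]
    weights = tf.nn.softmax(logits + bias, axis=-1)  # bias shape matches exactly
    print(weights.shape)                             # (2, 4, 5, 7)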
official/nlp/modeling/layers/position_embedding_test.py (view file @ c02bca43)

@@ -14,6 +14,8 @@
 # ==============================================================================
 """Tests for Keras-based positional embedding layer."""

+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

@@ -55,5 +57,32 @@ class RelativePositionEmbeddingLayerTest(keras_parameterized.TestCase):
     self.assertAllEqual(output_tensor, expected_output_tensor)


+@keras_parameterized.run_all_keras_modes
+class RelativePositionBiasTest(keras_parameterized.TestCase):
+
+  @parameterized.named_parameters(("bidirectional", True),
+                                  ("unidirectional", False))
+  def test_relative_position_bias(self, bidirectional):
+    query = tf.zeros((4, 4, 2))
+    key = tf.zeros((4, 2, 2))
+    l = position_embedding.RelativePositionBias(
+        num_heads=3, bidirectional=bidirectional, name="foo")
+    self.assertEqual(l(query, key).shape, (4, 3, 4, 2))
+    self.assertLen(l.trainable_variables, 1)
+    self.assertEqual(l.trainable_variables[0].name, "foo/rel_embedding:0")
+
+  def test_relative_position_bucket(self):
+    context_position = tf.range(3)[:, None]
+    memory_position = tf.range(2)[None, :]
+    relative_position = memory_position - context_position
+    outputs = position_embedding._relative_position_bucket(relative_position)
+    self.assertAllEqual(outputs.numpy(),
+                        np.array([[0, 17], [1, 0], [2, 1]]))
+    outputs = position_embedding._relative_position_bucket(
+        relative_position, bidirectional=False)
+    self.assertAllEqual(outputs.numpy(),
+                        np.array([[0, 0], [1, 0], [2, 1]]))
+
+
 if __name__ == "__main__":
   tf.test.main()
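The expected arrays in test_relative_position_bucket can be checked by hand. A small numpy re-derivation of the bidirectional case, which works only because every distance in this 3x2 example stays below max_exact = 8, so the logarithmic branch of _relative_position_bucket never triggers:

    import numpy as np

    # relative_position = memory_position - context_position for qlen=3, klen=2.
    relative_position = np.array([[0, 1], [-1, 0], [-2, -1]])

    n = -relative_position                  # distance from query to key
    bucket = (n < 0).astype(np.int32) * 16  # right-of-query positions get +num_buckets//2
    bucket += np.abs(n)                     # small distances map to themselves
    print(bucket)                           # [[ 0 17] [ 1  0] [ 2  1]]

In the unidirectional case positive relative positions are clipped to zero and no halving occurs, giving max(-relative_position, 0) = [[0, 0], [1, 0], [2, 1]], which matches the second expectation.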