Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9b559ad1
Commit
9b559ad1
authored
Nov 30, 2021
by
Yuexin Wu
Committed by
A. Unique TensorFlower
Nov 30, 2021
Browse files
Allow ReZero to take 3 inputs that can be used for compressing sequence length.
PiperOrigin-RevId: 413261350
parent
c57e975a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
7 deletions
+39
-7
official/nlp/modeling/layers/rezero_transformer.py
official/nlp/modeling/layers/rezero_transformer.py
+25
-7
official/nlp/modeling/layers/rezero_transformer_test.py
official/nlp/modeling/layers/rezero_transformer_test.py
+14
-0
No files found.
official/nlp/modeling/layers/rezero_transformer.py
View file @
9b559ad1
...
...
@@ -80,8 +80,15 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self
.
_use_layer_norm
=
use_layer_norm
def
build
(
self
,
input_shape
):
input_tensor
=
input_shape
[
0
]
if
len
(
input_shape
)
==
2
else
input_shape
input_tensor_shape
=
tf
.
TensorShape
(
input_tensor
)
if
isinstance
(
input_shape
,
tf
.
TensorShape
):
input_tensor_shape
=
input_shape
elif
isinstance
(
input_shape
,
(
list
,
tuple
)):
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
else
:
raise
ValueError
(
"The type of input shape argument is not supported, got: %s"
%
type
(
input_shape
))
if
len
(
input_tensor_shape
.
as_list
())
!=
3
:
raise
ValueError
(
"TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width]."
)
...
...
@@ -198,19 +205,30 @@ class ReZeroTransformer(tf.keras.layers.Layer):
self
.
_rezero_a
.
assign
(
0.
)
def
call
(
self
,
inputs
):
if
isinstance
(
inputs
,
(
list
,
tuple
))
and
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
3
:
input_tensor
,
key_value
,
attention_mask
=
inputs
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
(
self
.
__class__
,
len
(
inputs
)))
else
:
input_tensor
,
attention_mask
=
(
inputs
,
None
)
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
if
self
.
_output_range
:
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
target_tensor
=
input_tensor
if
key_value
is
None
:
key_value
=
input_tensor
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
input_tensor
,
attention_mask
=
attention_mask
)
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
attention_output
=
target_tensor
+
self
.
_rezero_a
*
attention_output
if
self
.
_use_layer_norm
:
...
...
official/nlp/modeling/layers/rezero_transformer_test.py
View file @
9b559ad1
...
...
@@ -124,6 +124,20 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:])
def
test_separate_qkv
(
self
):
test_layer
=
rezero_transformer
.
ReZeroTransformer
(
num_attention_heads
=
2
,
intermediate_size
=
128
,
intermediate_activation
=
'relu'
,
kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
))
# Forward path.
q_tensor
=
tf
.
zeros
([
2
,
4
,
16
],
dtype
=
tf
.
float32
)
kv_tensor
=
tf
.
zeros
([
2
,
8
,
16
],
dtype
=
tf
.
float32
)
dummy_mask
=
tf
.
zeros
([
2
,
4
,
8
],
dtype
=
tf
.
float32
)
inputs
=
[
q_tensor
,
kv_tensor
,
dummy_mask
]
output
=
test_layer
(
inputs
)
self
.
assertEqual
(
output
.
shape
,
q_tensor
.
shape
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment