ModelZoo / ResNet50_tensorflow / Commits

Commit a3be7365
Authored Apr 01, 2020 by George Karpenkov
Committed by A. Unique TensorFlower on Apr 01, 2020

Add a CompiledTransformer layer, which is compiled with XLA.

PiperOrigin-RevId: 304222530
Parent: eb6e819d
Showing 2 changed files with 27 additions and 16 deletions (+27, -16).
official/nlp/modeling/layers/transformer.py       +8   -0
official/nlp/modeling/layers/transformer_test.py  +19  -16
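For context, the new CompiledTransformer introduced in this commit is a drop-in subclass of the existing Transformer whose call() is wrapped for XLA compilation. Below is a minimal usage sketch, assuming the constructor arguments used throughout transformer_test.py and an arbitrary example input shape; the snippet is illustrative and not part of the commit itself.

import tensorflow as tf

from official.nlp.modeling.layers import transformer

# Constructor arguments mirror those used in transformer_test.py.
layer = transformer.CompiledTransformer(
    num_attention_heads=10,
    intermediate_size=2048,
    intermediate_activation='relu')

# Example input: (batch, sequence_length, width). The width must divide evenly
# by the number of attention heads; the shape here is an arbitrary choice.
inputs = tf.ones((2, 16, 80), dtype=tf.float32)

# In eager mode the overridden call() runs under
# tf.function(experimental_compile=True), i.e. it is compiled with XLA.
outputs = layer(inputs)
print(outputs.shape)  # (2, 16, 80): same shape as the input.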
official/nlp/modeling/layers/transformer.py
@@ -23,6 +23,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import attention
 from official.nlp.modeling.layers import dense_einsum
+from official.nlp.modeling.layers.util import tf_function_if_eager


 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -219,3 +220,10 @@ class Transformer(tf.keras.layers.Layer):
     layer_output = self._output_layer_norm(layer_output + attention_output)

     return layer_output
+
+
+class CompiledTransformer(Transformer):
+
+  @tf_function_if_eager(experimental_compile=True)
+  def call(self, inputs):
+    return super(CompiledTransformer, self).call(inputs)
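The tf_function_if_eager helper imported above lives in official/nlp/modeling/layers/util.py, which is not part of this diff. A rough sketch of what such a decorator plausibly does (an assumption, not the actual contents of util.py): it forwards its keyword arguments, here experimental_compile=True, to tf.function, so the wrapped call() is traced and XLA-compiled when executing eagerly and left untouched otherwise.

import functools

import tensorflow as tf


def tf_function_if_eager(**kwargs):
  """Hypothetical sketch of the helper: apply tf.function only in eager mode."""

  def decorator(func):
    # tf.function(..., experimental_compile=True) requests XLA compilation of
    # the traced function (TF 2.1/2.2-era spelling; later renamed jit_compile).
    compiled = tf.function(func, **kwargs)

    @functools.wraps(func)
    def wrapper(*args, **call_kwargs):
      if tf.executing_eagerly():
        return compiled(*args, **call_kwargs)
      # Inside an existing graph or tf.function there is nothing to wrap.
      return func(*args, **call_kwargs)

    return wrapper

  return decorator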
official/nlp/modeling/layers/transformer_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -28,14 +29,16 @@ from official.nlp.modeling.layers import transformer
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
+@parameterized.parameters(transformer.Transformer,
+                          transformer.CompiledTransformer)
 class TransformerLayerTest(keras_parameterized.TestCase):

   def tearDown(self):
     super(TransformerLayerTest, self).tearDown()
     tf.keras.mixed_precision.experimental.set_policy('float32')

-  def test_layer_creation(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -47,8 +50,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_mask(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation_with_mask(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -62,8 +65,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

-  def test_layer_creation_with_incorrect_mask_fails(self):
-    test_layer = transformer.Transformer(
+  def test_layer_creation_with_incorrect_mask_fails(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -76,8 +79,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'When passing a mask tensor.*'):
       _ = test_layer([data_tensor, mask_tensor])

-  def test_layer_invocation(self):
-    test_layer = transformer.Transformer(
+  def test_layer_invocation(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -97,8 +100,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         (batch_size, sequence_length, width))
     _ = model.predict(input_data)

-  def test_layer_invocation_with_mask(self):
-    test_layer = transformer.Transformer(
+  def test_layer_invocation_with_mask(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -124,9 +127,9 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])

-  def test_layer_invocation_with_float16_dtype(self):
+  def test_layer_invocation_with_float16_dtype(self, transformer_cls):
     tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
-    test_layer = transformer.Transformer(
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu')
@@ -152,8 +155,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
         2, size=(batch_size, sequence_length, sequence_length))
     _ = model.predict([input_data, mask_data])

-  def test_transform_with_initializer(self):
-    test_layer = transformer.Transformer(
+  def test_transform_with_initializer(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu',
@@ -166,8 +169,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
     # The default output of a transformer layer should be the same as the input.
     self.assertEqual(data_tensor.shape.as_list(), output.shape.as_list())

-  def test_dynamic_layer_sequence(self):
-    test_layer = transformer.Transformer(
+  def test_dynamic_layer_sequence(self, transformer_cls):
+    test_layer = transformer_cls(
         num_attention_heads=10,
         intermediate_size=2048,
         intermediate_activation='relu',
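The test change above relies on class-level @parameterized.parameters: every test method in TransformerLayerTest now runs once per listed layer class, with that class passed in as the extra transformer_cls argument, so each test covers both Transformer and CompiledTransformer in every Keras mode exercised by run_all_keras_modes. A standalone illustration of the same pattern (a hypothetical test, not part of this commit):

from absl.testing import parameterized
import tensorflow as tf


# Class-level @parameterized.parameters mirrors the change above: each listed
# value produces one run of every test method, bound to the extra argument.
@parameterized.parameters(tf.keras.layers.Dense, tf.keras.layers.Dropout)
class ClassLevelParameterizationTest(parameterized.TestCase):

  def test_value_is_injected(self, layer_cls):
    # One run receives Dense, the other Dropout.
    self.assertTrue(callable(layer_cls))


if __name__ == '__main__':
  tf.test.main()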