Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
164ee0e0
Commit
164ee0e0
authored
Sep 26, 2020
by
Le Hou
Committed by
A. Unique TensorFlower
Sep 26, 2020
Browse files
Fix "source_tensor is used before assignment"
PiperOrigin-RevId: 333878584
parent
2ebe7c3c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
0 deletions
+31
-0
official/nlp/keras_nlp/layers/transformer_encoder_block.py
official/nlp/keras_nlp/layers/transformer_encoder_block.py
+3
-0
official/nlp/keras_nlp/layers/transformer_encoder_block_test.py
...al/nlp/keras_nlp/layers/transformer_encoder_block_test.py
+28
-0
No files found.
official/nlp/keras_nlp/layers/transformer_encoder_block.py
View file @
164ee0e0
...
...
@@ -240,6 +240,9 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
input_tensor
,
attention_mask
=
(
inputs
,
None
)
if
self
.
_output_range
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
...
...
official/nlp/keras_nlp/layers/transformer_encoder_block_test.py
View file @
164ee0e0
...
...
@@ -137,6 +137,34 @@ class TransformerEncoderBlockLayerTest(keras_parameterized.TestCase):
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.003
)
def
test_layer_output_range_with_pre_norm
(
self
,
transformer_cls
):
test_layer
=
transformer_cls
(
num_attention_heads
=
10
,
inner_dim
=
2048
,
inner_activation
=
'relu'
,
norm_first
=
True
)
sequence_length
=
21
width
=
80
batch_size
=
6
input_data
=
10
*
np
.
random
.
random_sample
(
(
batch_size
,
sequence_length
,
width
))
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
,
sequence_length
))
output_tensor
=
test_layer
([
input_data
,
mask_data
])
# The layer only attends to the first token and outputs the first token
# embeeding.
new_layer
=
transformer_cls
(
num_attention_heads
=
10
,
inner_dim
=
2048
,
inner_activation
=
'relu'
,
output_range
=
1
,
norm_first
=
True
)
_
=
new_layer
([
input_data
,
mask_data
])
new_layer
.
set_weights
(
test_layer
.
get_weights
())
new_output_tensor
=
new_layer
([
input_data
,
mask_data
])
self
.
assertAllClose
(
new_output_tensor
,
output_tensor
[:,
0
:
1
,
:],
atol
=
5e-5
,
rtol
=
0.003
)
def
test_layer_invocation_with_float16_dtype
(
self
,
transformer_cls
):
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'mixed_float16'
)
test_layer
=
transformer_cls
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment