ModelZoo / ResNet50_tensorflow · Commits
Commit b1aa44d9 authored Apr 13, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 368260712
Parent: 782a0299
Showing 2 changed files with 21 additions and 13 deletions (+21 -13)

official/nlp/modeling/layers/transformer_scaffold.py   +18 -11
official/nlp/projects/bigbird/encoder.py                 +3  -2
official/nlp/modeling/layers/transformer_scaffold.py
@@ -249,7 +249,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     base_config = super(TransformerScaffold, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(self, inputs):
+  def call(self, inputs, training=None):
     if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
       input_tensor, attention_mask = inputs
     else:
@@ -257,27 +257,31 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._norm_first:
       source_tensor = input_tensor
-      input_tensor = self._attention_layer_norm(input_tensor)
+      input_tensor = self._attention_layer_norm(input_tensor, training=training)
 
     attention_output = self._attention_layer(
-        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
-    attention_output = self._attention_dropout(attention_output)
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask,
+        training=training)
+    attention_output = self._attention_dropout(attention_output,
+                                               training=training)
 
     if self._norm_first:
       attention_output = source_tensor + attention_output
     else:
       attention_output = self._attention_layer_norm(input_tensor +
-                                                    attention_output)
+                                                    attention_output,
+                                                    training=training)
 
     if self._norm_first:
       source_attention_output = attention_output
-      attention_output = self._output_layer_norm(attention_output)
+      attention_output = self._output_layer_norm(attention_output,
+                                                 training=training)
 
     if self._feedforward_block is None:
       intermediate_output = self._intermediate_dense(attention_output)
       intermediate_output = self._intermediate_activation_layer(
           intermediate_output)
-      layer_output = self._output_dense(intermediate_output)
-      layer_output = self._output_dropout(layer_output)
+      layer_output = self._output_dense(intermediate_output, training=training)
+      layer_output = self._output_dropout(layer_output, training=training)
       # During mixed precision training, attention_output is from layer norm
       # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
       # add.
@@ -285,14 +289,17 @@ class TransformerScaffold(tf.keras.layers.Layer):
       if self._norm_first:
         layer_output = source_attention_output + layer_output
       else:
-        layer_output = self._output_layer_norm(layer_output + attention_output)
+        layer_output = self._output_layer_norm(layer_output + attention_output,
+                                               training=training)
     else:
       if self._norm_first:
         # if norm_first, assume the feedforward block will not apply layer norm
-        layer_output = self._feedforward_block(attention_output)
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)
         layer_output += source_attention_output
       else:
         # if not norm_first, assume that the feedforwad does apply layer norm
-        layer_output = self._feedforward_block(attention_output)
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)
 
     return layer_output
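Note (not part of the commit): the change above threads the Keras `training` flag from TransformerScaffold.call into its attention, dropout, norm and feed-forward sublayers, rather than leaving the phase to Keras's implicit propagation. A minimal sketch of the same pattern with stock tf.keras layers; the MyBlock name and shapes are hypothetical, not from this repo:

# Sketch only -- MyBlock and its shapes are illustrative, not repo code.
import tensorflow as tf

class MyBlock(tf.keras.layers.Layer):

  def __init__(self, units, dropout_rate=0.1, **kwargs):
    super().__init__(**kwargs)
    self._dense = tf.keras.layers.Dense(units)
    self._dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs, training=None):
    x = self._dense(inputs)
    # Forwarding `training` keeps dropout active only during training,
    # even when this block is nested inside another custom layer.
    return self._dropout(x, training=training)

block = MyBlock(8)
x = tf.ones([2, 4])
y_train = block(x, training=True)   # dropout applied
y_infer = block(x, training=False)  # dropout disabled

Passing the flag explicitly makes the training/inference behavior of the nested dropout layers unambiguous when the scaffold is invoked from another layer's call().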
official/nlp/projects/bigbird/encoder.py
@@ -28,7 +28,7 @@ from official.nlp.projects.bigbird import recomputing_dropout
 class RecomputeTransformerLayer(layers.TransformerScaffold):
   """Transformer layer that recomputes the forward pass during backpropagation."""
 
-  def call(self, inputs):
+  def call(self, inputs, training=None):
     emb, mask = inputs
     def f(*args):
       # recompute_grad can only handle tensor inputs. so we enumerate the
@@ -39,7 +39,8 @@ class RecomputeTransformerLayer(layers.TransformerScaffold):
       # args[3]: mask[2] = encoder_to_mask
       # args[4]: mask[3] = blocked_encoder_mask
       x = super(RecomputeTransformerLayer,
-                self).call([args[0], [args[1], args[2], args[3], args[4]]])
+                self).call([args[0], [args[1], args[2], args[3], args[4]]],
+                           training=training)
       return x
 
     f = recompute_grad.recompute_grad(f)
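Note (not part of the commit): as the in-code comment says, recompute_grad can only handle tensor inputs, so the new `training` flag is not forwarded through f's positional arguments; it is captured by the closure and passed to the superclass call. A rough sketch of that pattern using the public tf.recompute_grad in place of the project's recompute_grad helper; the layers and shapes here are illustrative only:

# Sketch only -- tf.recompute_grad stands in for the project's helper.
import tensorflow as tf

dense = tf.keras.layers.Dense(16)
dense.build((None, 8))               # create variables before wrapping
dropout = tf.keras.layers.Dropout(0.1)

def forward(x, training=None):

  def f(t):
    # Only the tensor `t` flows through recompute_grad; the non-tensor
    # `training` flag is closed over from the enclosing scope instead.
    return dropout(dense(t), training=training)

  f = tf.recompute_grad(f)
  return f(x)

y = forward(tf.ones([2, 8]), training=True)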