ModelZoo / ResNet50_tensorflow · Commits

Commit d182e423
Authored Jun 20, 2022 by Hongkun Yu; committed by A. Unique TensorFlower on Jun 20, 2022

Internal change

PiperOrigin-RevId: 456159384
Parent: a471eb3b
Showing 2 changed files with 136 additions and 2 deletions (+136, -2):

official/nlp/modeling/layers/pack_optimization.py (+113, -0)
official/nlp/modeling/layers/pack_optimization_test.py (+23, -2)
official/nlp/modeling/layers/pack_optimization.py
@@ -16,8 +16,10 @@
 from typing import Dict

 import tensorflow as tf

 from official.modeling import tf_utils
+from official.nlp.modeling.layers import rezero_transformer
 from official.nlp.modeling.layers import self_attention_mask
 from official.nlp.modeling.layers import transformer_encoder_block
+from official.nlp.modeling.layers import transformer_scaffold


 def _packing_mask(segment_id, source_segment_id, dtype=tf.float32):
@@ -142,3 +144,114 @@ class StridedTransformerEncoderBlock(
     layer_output = tf.cast(layer_output, tf.float32)
     return self._output_layer_norm(layer_output + attention_output)
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class StridedReZeroTransformer(rezero_transformer.ReZeroTransformer):
+  """ReZeroTransformer for packing optimization to stride over inputs."""
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    if self._output_range is not None:
+      raise ValueError(f'{self.__class__} does not '
+                       'support `output_range` argument.')
+
+  def call(self, inputs, stride: tf.Tensor):
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError(f'Unexpected inputs to {self.__class__} with '
+                         f'length at {len(inputs)}.')
+    else:
+      input_tensor, key_value, attention_mask = (inputs, None, None)
+
+    target_tensor = input_tensor[:, ::stride, :]
+    if attention_mask is not None:
+      attention_mask = attention_mask[:, ::stride, :]
+    if key_value is None:
+      key_value = input_tensor
+
+    attention_output = self._attention_layer(
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
+    attention_output = self._attention_dropout(attention_output)
+    attention_output = target_tensor + self._rezero_a * attention_output
+    if self._use_layer_norm:
+      attention_output = self._attention_layer_norm(attention_output)
+    else:
+      attention_output = tf.cast(attention_output, tf.float32)
+
+    intermediate_output = self._intermediate_dense(attention_output)
+    intermediate_output = self._inner_activation_layer(intermediate_output)
+    layer_output = self._output_dense(intermediate_output)
+    layer_output = self._output_dropout(layer_output)
+    layer_output = attention_output + tf.cast(
+        self._rezero_a_ffn * layer_output, tf.float32)
+    if self._use_layer_norm:
+      layer_output = self._output_layer_norm(layer_output)
+
+    return layer_output
+
+
+@tf.keras.utils.register_keras_serializable(package='Text')
+class StridedTransformerScaffold(transformer_scaffold.TransformerScaffold):
+  """TransformerScaffold for packing optimization to stride over inputs."""
+
+  def call(self, inputs, stride: tf.Tensor, training=None):
+    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
+      input_tensor, attention_mask = inputs
+    else:
+      input_tensor, attention_mask = (inputs, None)
+
+    if self._norm_first:
+      source_tensor = input_tensor[:, ::stride, :]
+      input_tensor = self._attention_layer_norm(input_tensor,
+                                                training=training)
+    if attention_mask is not None:
+      attention_mask = attention_mask[:, ::stride, :]
+
+    target_tensor = input_tensor[:, ::stride, :]
+    attention_output = self._attention_layer(
+        query=target_tensor,
+        value=input_tensor,
+        attention_mask=attention_mask,
+        training=training)
+    attention_output = self._attention_dropout(attention_output,
+                                               training=training)
+    if self._norm_first:
+      attention_output = source_tensor + attention_output
+    else:
+      attention_output = self._attention_layer_norm(
+          target_tensor + attention_output, training=training)
+
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output,
+                                                 training=training)
+
+    if self._feedforward_block is None:
+      intermediate_output = self._intermediate_dense(attention_output)
+      intermediate_output = self._intermediate_activation_layer(
+          intermediate_output)
+      layer_output = self._output_dense(intermediate_output,
+                                        training=training)
+      layer_output = self._output_dropout(layer_output, training=training)
+      layer_output = tf.cast(layer_output, tf.float32)
+      if self._norm_first:
+        layer_output = source_attention_output + layer_output
+      else:
+        layer_output = self._output_layer_norm(layer_output + attention_output,
+                                               training=training)
+    else:
+      if self._norm_first:
+        # if norm_first, assume the feedforward block will not apply layer norm
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)
+        layer_output += source_attention_output
+      else:
+        # if not norm_first, assume that the feedforward does apply layer norm
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)
+    return layer_output
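For orientation, here is a minimal sketch of how the new strided layers are driven, mirroring the shapes in the new tests below; the tensor sizes and the stride value are illustrative, not required by the API:

import tensorflow as tf
from official.nlp.modeling.layers import pack_optimization

# A batch of 2 packed sequences of length 4 with hidden size 8 (illustrative).
inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)

layer = pack_optimization.StridedReZeroTransformer(
    num_attention_heads=2, inner_dim=4, inner_activation="relu")
# With stride=2 the layer computes attention only from every second query
# position (keys/values stay full length), so the sequence axis halves:
# (2, 4, 8) -> (2, 2, 8).
outputs = layer([inputs, attention_mask],
                stride=tf.constant(2, dtype=tf.int32))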
official/nlp/modeling/layers/pack_optimization_test.py
@@ -37,8 +37,29 @@ class PackOptimizationTest(tf.test.TestCase):
     attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
     transformer = pack_optimization.StridedTransformerEncoderBlock(
         num_attention_heads=2, inner_dim=4, inner_activation="relu")
-    _ = transformer([inputs, attention_mask],
-                    stride=tf.constant(2, dtype=tf.int32))
+    outputs = transformer([inputs, attention_mask],
+                          stride=tf.constant(2, dtype=tf.int32))
+    self.assertEqual(outputs.shape, (2, 2, 8))
+
+  def test_strided_rezero_transformer(self):
+    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
+    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
+    transformer = pack_optimization.StridedReZeroTransformer(
+        num_attention_heads=2, inner_dim=4, inner_activation="relu")
+    outputs = transformer([inputs, attention_mask],
+                          stride=tf.constant(2, dtype=tf.int32))
+    self.assertEqual(outputs.shape, (2, 2, 8))
+
+  def test_strided_scaffold(self):
+    inputs = tf.zeros((2, 4, 8), dtype=tf.float32)
+    attention_mask = tf.ones((2, 4, 4), dtype=tf.float32)
+    test_layer = pack_optimization.StridedTransformerScaffold(
+        num_attention_heads=2, inner_dim=128, inner_activation="relu")
+    outputs = test_layer([inputs, attention_mask],
+                         stride=tf.constant(2, dtype=tf.int32))
+    self.assertEqual(outputs.shape, (2, 2, 8))


 if __name__ == "__main__":
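The asserted (2, 2, 8) output shape falls directly out of the strided slicing both layers apply to their query tensor; a tiny standalone sketch of that slicing (values illustrative):

import tensorflow as tf

x = tf.reshape(tf.range(2 * 4 * 3), (2, 4, 3))
stride = tf.constant(2, dtype=tf.int32)
# x[:, ::stride, :] keeps every `stride`-th position along the sequence axis,
# which is why a length-4 input yields a length-2 output in the tests above.
y = x[:, ::stride, :]
print(y.shape)  # (2, 2, 3)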