Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
59b5985e
Commit
59b5985e
authored
Jun 07, 2022
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jun 07, 2022
Browse files
Internal change
PiperOrigin-RevId: 453490507
parent
ee708859
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
221 additions
and
14 deletions
+221
-14
official/nlp/modeling/layers/__init__.py
official/nlp/modeling/layers/__init__.py
+2
-0
official/nlp/modeling/layers/pack_optimization.py
official/nlp/modeling/layers/pack_optimization.py
+144
-0
official/nlp/modeling/layers/pack_optimization_test.py
official/nlp/modeling/layers/pack_optimization_test.py
+45
-0
official/nlp/modeling/layers/self_attention_mask.py
official/nlp/modeling/layers/self_attention_mask.py
+30
-14
No files found.
official/nlp/modeling/layers/__init__.py
View file @
59b5985e
...
...
@@ -34,6 +34,8 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from
official.nlp.modeling.layers.mobile_bert_layers
import
MobileBertTransformer
from
official.nlp.modeling.layers.multi_channel_attention
import
*
from
official.nlp.modeling.layers.on_device_embedding
import
OnDeviceEmbedding
from
official.nlp.modeling.layers.pack_optimization
import
PackBertEmbeddings
from
official.nlp.modeling.layers.pack_optimization
import
StridedTransformerEncoderBlock
from
official.nlp.modeling.layers.position_embedding
import
PositionEmbedding
from
official.nlp.modeling.layers.position_embedding
import
RelativePositionBias
from
official.nlp.modeling.layers.position_embedding
import
RelativePositionEmbedding
...
...
official/nlp/modeling/layers/pack_optimization.py
0 → 100644
View file @
59b5985e
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pack sequence optimization on accelerators."""
from
typing
import
Dict
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling.layers
import
self_attention_mask
from
official.nlp.modeling.layers
import
transformer_encoder_block
def
_packing_mask
(
segment_id
,
source_segment_id
,
dtype
=
tf
.
float32
):
"""Calculates a segment mask for attention.
Args:
segment_id: [B, T]
source_segment_id: [B, S]
dtype: data type of generated mask.
Returns:
segment_mask: [B, T, S]
"""
if
segment_id
is
None
or
source_segment_id
is
None
:
return
None
# Compute [B, T, S] = [B, T, 1] == [B, 1, S]
return
tf
.
cast
(
tf
.
equal
(
tf
.
expand_dims
(
segment_id
,
2
),
tf
.
expand_dims
(
source_segment_id
,
1
)),
dtype
=
dtype
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
PackBertEmbeddings
(
tf
.
keras
.
layers
.
Layer
):
"""Performs packing tricks for BERT inputs to improve TPU utilization."""
def
__init__
(
self
,
pack_sequences
:
int
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
pack_sequences
=
pack_sequences
def
call
(
self
,
input_embeddings
:
tf
.
Tensor
,
input_mask
:
tf
.
Tensor
)
->
Dict
[
str
,
tf
.
Tensor
]:
batch_size
,
seq_len
,
embedding_dim
=
tf_utils
.
get_shape_list
(
input_embeddings
,
expected_rank
=
3
)
example_ids
=
None
reduced_batch_size
=
batch_size
//
self
.
pack_sequences
packed_seq_len
=
self
.
pack_sequences
*
seq_len
packed_embeddings
=
tf
.
reshape
(
input_embeddings
,
[
reduced_batch_size
,
packed_seq_len
,
embedding_dim
])
input_mask
=
tf
.
reshape
(
input_mask
,
[
reduced_batch_size
,
packed_seq_len
])
example_ids
=
1
+
tf
.
range
(
self
.
pack_sequences
)
# Shape: [batch_size, seq_len, pack_sequences].
example_ids
=
tf
.
tile
(
example_ids
[
None
,
:,
None
],
[
reduced_batch_size
,
1
,
seq_len
])
example_ids
=
tf
.
reshape
(
example_ids
,
[
reduced_batch_size
,
packed_seq_len
])
example_ids
=
tf
.
where
(
tf
.
math
.
equal
(
input_mask
,
0
),
tf
.
zeros_like
(
example_ids
),
example_ids
)
packing_mask
=
_packing_mask
(
example_ids
,
example_ids
,
dtype
=
tf
.
bool
)
attention_mask
=
self_attention_mask
.
get_mask
(
packed_embeddings
,
input_mask
,
dtype
=
tf
.
bool
)
combined_attention_mask
=
tf
.
cast
(
tf
.
math
.
logical_and
(
attention_mask
,
packing_mask
),
tf
.
float32
)
return
dict
(
packed_embeddings
=
packed_embeddings
,
combined_attention_mask
=
combined_attention_mask
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
StridedTransformerEncoderBlock
(
transformer_encoder_block
.
TransformerEncoderBlock
):
"""Transformer layer for packing optimization to stride over inputs."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
self
.
_output_range
is
not
None
:
raise
ValueError
(
'StridedTransformerEncoderBlock does not '
'support `output_range` argument.'
)
def
call
(
self
,
inputs
,
stride
:
tf
.
Tensor
):
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
3
:
input_tensor
,
key_value
,
attention_mask
=
inputs
else
:
raise
ValueError
(
'Unexpected inputs to %s with length at %d'
%
(
self
.
__class__
,
len
(
inputs
)))
else
:
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
if
self
.
_norm_first
:
source_tensor
=
input_tensor
[:,
::
stride
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm_kv
(
key_value
)
target_tensor
=
input_tensor
[:,
::
stride
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
::
stride
,
:]
if
key_value
is
None
:
key_value
=
input_tensor
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
# Important to not combine `self._norm_first` and
# `self._use_query_residual` into one if clause because else is only for
# `_norm_first == False`.
if
self
.
_use_query_residual
:
attention_output
=
source_tensor
+
attention_output
else
:
if
self
.
_use_query_residual
:
attention_output
=
target_tensor
+
attention_output
attention_output
=
self
.
_attention_layer_norm
(
attention_output
)
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
inner_output
=
self
.
_intermediate_dense
(
attention_output
)
inner_output
=
self
.
_intermediate_activation_layer
(
inner_output
)
inner_output
=
self
.
_inner_dropout_layer
(
inner_output
)
layer_output
=
self
.
_output_dense
(
inner_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
if
self
.
_norm_first
:
return
source_attention_output
+
layer_output
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
return
self
.
_output_layer_norm
(
layer_output
+
attention_output
)
official/nlp/modeling/layers/pack_optimization_test.py
0 → 100644
View file @
59b5985e
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pack_optimization."""
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
pack_optimization
class
PackOptimizationTest
(
tf
.
test
.
TestCase
):
def
test_bert_embedding_packing
(
self
):
batch_size
,
seq_len
,
embed_dim
=
2
,
4
,
8
pack_sequences
=
2
token_and_position_embed
=
tf
.
ones
((
batch_size
,
seq_len
,
embed_dim
),
dtype
=
tf
.
float32
)
input_mask
=
tf
.
ones
((
batch_size
,
seq_len
),
dtype
=
tf
.
int32
)
layer
=
pack_optimization
.
PackBertEmbeddings
(
pack_sequences
=
pack_sequences
)
outputs
=
layer
(
token_and_position_embed
,
input_mask
)
self
.
assertEqual
(
outputs
[
"packed_embeddings"
].
shape
,
(
1
,
8
,
embed_dim
))
self
.
assertEqual
(
outputs
[
"combined_attention_mask"
].
shape
,
(
1
,
8
,
8
))
def
test_strided_transformer_encoder_block
(
self
):
inputs
=
tf
.
zeros
((
2
,
4
,
8
),
dtype
=
tf
.
float32
)
attention_mask
=
tf
.
ones
((
2
,
4
,
4
),
dtype
=
tf
.
float32
)
transformer
=
pack_optimization
.
StridedTransformerEncoderBlock
(
num_attention_heads
=
2
,
inner_dim
=
4
,
inner_activation
=
"relu"
)
_
=
transformer
([
inputs
,
attention_mask
],
stride
=
tf
.
constant
(
2
,
dtype
=
tf
.
int32
))
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/modeling/layers/self_attention_mask.py
View file @
59b5985e
...
...
@@ -13,10 +13,38 @@
# limitations under the License.
"""Keras layer that creates a self-attention mask."""
from
typing
import
Optional
import
tensorflow
as
tf
def
get_mask
(
inputs
:
tf
.
Tensor
,
to_mask
:
tf
.
Tensor
,
dtype
:
Optional
[
tf
.
DType
]
=
None
)
->
tf
.
Tensor
:
"""Gets a 3D self-attention mask.
Args:
inputs: from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length,
...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
dtype: the output Tensor dtype.
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
from_shape
=
tf
.
shape
(
inputs
)
batch_size
=
from_shape
[
0
]
from_seq_length
=
from_shape
[
1
]
dtype
=
inputs
.
dtype
if
dtype
is
None
else
dtype
to_shape
=
tf
.
shape
(
to_mask
)
to_seq_length
=
to_shape
[
1
]
to_mask
=
tf
.
cast
(
tf
.
reshape
(
to_mask
,
[
batch_size
,
1
,
to_seq_length
]),
dtype
=
dtype
)
return
tf
.
broadcast_to
(
to_mask
,
[
batch_size
,
from_seq_length
,
to_seq_length
])
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
SelfAttentionMask
(
tf
.
keras
.
layers
.
Layer
):
"""Create 3D attention mask from a 2D tensor mask.
...
...
@@ -33,16 +61,4 @@ class SelfAttentionMask(tf.keras.layers.Layer):
if
isinstance
(
inputs
,
list
)
and
to_mask
is
None
:
to_mask
=
inputs
[
1
]
inputs
=
inputs
[
0
]
from_shape
=
tf
.
shape
(
inputs
)
batch_size
=
from_shape
[
0
]
from_seq_length
=
from_shape
[
1
]
to_shape
=
tf
.
shape
(
to_mask
)
to_seq_length
=
to_shape
[
1
]
to_mask
=
tf
.
cast
(
tf
.
reshape
(
to_mask
,
[
batch_size
,
1
,
to_seq_length
]),
dtype
=
inputs
.
dtype
)
return
tf
.
broadcast_to
(
to_mask
,
[
batch_size
,
from_seq_length
,
to_seq_length
])
return
get_mask
(
inputs
,
to_mask
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment