ModelZoo / ResNet50_tensorflow

Commit fd270433, authored Dec 07, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 346088379

Parent: 8ffa9448
Showing 2 changed files, with 18 additions and 3 deletions (+18 −3).
official/nlp/modeling/networks/packed_sequence_embedding.py (+15 −3)
official/nlp/modeling/networks/packed_sequence_embedding_test.py (+3 −0)
official/nlp/modeling/networks/packed_sequence_embedding.py
@@ -35,7 +35,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
   Arguments:
     vocab_size: The size of the token vocabulary.
     type_vocab_size: The size of the type vocabulary.
-    hidden_size: The hidden size for this encoder.
+    embedding_width: Width of token embeddings.
+    hidden_size: The output size for this encoder.
     max_seq_length: The maximum sequence length for this encoder.
     initializer: The initializer for the embedding portion of this encoder.
     dropout_rate: The dropout rate to apply before the encoding layers.
@@ -52,6 +53,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
def
__init__
(
self
,
vocab_size
,
type_vocab_size
,
embedding_width
,
hidden_size
,
max_seq_length
,
initializer
,
...
...
@@ -63,6 +65,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
config_dict
=
{
'vocab_size'
:
vocab_size
,
'type_vocab_size'
:
type_vocab_size
,
'embedding_width'
:
embedding_width
,
'hidden_size'
:
hidden_size
,
'max_seq_length'
:
max_seq_length
,
'initializer'
:
tf
.
keras
.
initializers
.
serialize
(
initializer
),
...
...
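The new 'embedding_width' entry above is what lets the widened constructor signature survive serialization. A minimal round-trip sanity check, assuming the class follows the usual Keras get_config() pattern and that the constructor arguments not shown in these hunks (e.g. dropout_rate) look as sketched here:

import tensorflow as tf
from official.nlp.modeling.networks import packed_sequence_embedding

# Sizes are illustrative only, chosen to mirror the test file below.
net = packed_sequence_embedding.PackedSequenceEmbedding(
    vocab_size=100,
    type_vocab_size=2,
    embedding_width=16,  # the argument introduced by this commit
    hidden_size=32,
    max_seq_length=32,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    dropout_rate=0.1)    # assumed remaining argument, not shown in the diff
assert net.get_config()['embedding_width'] == 16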
@@ -96,7 +99,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
hidden_size
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
name
=
'word_embeddings'
)
word_embeddings
=
embedding_layer
(
word_ids
)
...
...
@@ -113,7 +116,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
type_embeddings
=
(
layers
.
OnDeviceEmbedding
(
vocab_size
=
type_vocab_size
,
embedding_width
=
hidden_size
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
use_one_hot
=
True
,
name
=
'type_embeddings'
)(
type_ids
))
...
...
@@ -127,6 +130,15 @@ class PackedSequenceEmbedding(tf.keras.Model):
rate
=
dropout_rate
,
dtype
=
tf
.
float32
)(
embeddings
)
if
embedding_width
!=
hidden_size
:
embeddings
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
'...x,xy->...y'
,
output_shape
=
hidden_size
,
bias_axes
=
None
,
kernel_initializer
=
initializer
,
name
=
'embedding_projection'
)(
embeddings
)
attention_mask
=
layers
.
SelfAttentionMask
()([
embeddings
,
mask
])
if
sub_seq_mask
is
not
None
:
attention_mask
=
tf
.
keras
.
layers
.
Lambda
(
...
...
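Taken together, the hunks above factorize the embedding: tokens (and token types) are now embedded at embedding_width, and a bias-free einsum projection maps them up to hidden_size only when the two widths differ. A self-contained sketch of the same pattern built from stock Keras layers (all names and sizes here are illustrative, not taken from the commit):

import tensorflow as tf

vocab_size, embedding_width, hidden_size, seq_len = 100, 16, 32, 8

word_ids = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
# Embed at the narrower embedding_width, as the patched code now does.
embeddings = tf.keras.layers.Embedding(vocab_size, embedding_width)(word_ids)
# Project up to hidden_size only when the widths differ, mirroring the
# new `if embedding_width != hidden_size:` branch.
if embedding_width != hidden_size:
  embeddings = tf.keras.layers.experimental.EinsumDense(
      '...x,xy->...y',
      output_shape=hidden_size,
      bias_axes=None,
      name='embedding_projection')(embeddings)

model = tf.keras.Model(word_ids, embeddings)
print(model.output_shape)  # (None, 8, 32)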
official/nlp/modeling/networks/packed_sequence_embedding_test.py
@@ -45,10 +45,12 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     vocab_size = 100
     max_position_embeddings = 32
     type_vocab_size = 2
+    embedding_width = 16
     hidden_size = 32
     embedding_cfg = dict(
         vocab_size=vocab_size,
         type_vocab_size=2,
+        embedding_width=embedding_width,
         hidden_size=hidden_size,
         max_seq_length=max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
@@ -103,6 +105,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     embedding_cfg = dict(
         vocab_size=100,
         type_vocab_size=2,
+        embedding_width=64,
         hidden_size=64,
         max_seq_length=32,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
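For the sizes used in the first test (vocab_size=100, embedding_width=16, hidden_size=32), the factorization trades a direct V x H table for a V x E table plus an E x H projection kernel. Counting only those two weight matrices:

unfactorized = 100 * 32            # V * H         = 3200 weights
factorized = 100 * 16 + 16 * 32    # V * E + E * H = 2112 weights

The saving grows with vocabulary size, which is the usual motivation for factorized embeddings of this kind (as popularized by ALBERT).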