Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b708fd68
Commit
b708fd68
authored
Jun 19, 2020
by
A. Unique TensorFlower
Browse files
Add ELECTRA TF 2.x pretrainer.
Contributed by mickeystroller PiperOrigin-RevId: 317411747
parent
819c52f0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
480 additions
and
6 deletions
+480
-6
official/nlp/modeling/models/electra_pretrainer.py
official/nlp/modeling/models/electra_pretrainer.py
+307
-0
official/nlp/modeling/models/electra_pretrainer_test.py
official/nlp/modeling/models/electra_pretrainer_test.py
+156
-0
official/nlp/modeling/networks/transformer_encoder.py
official/nlp/modeling/networks/transformer_encoder.py
+17
-6
No files found.
official/nlp/modeling/models/electra_pretrainer.py
0 → 100644
View file @
b708fd68
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
copy
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling
import
layers
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
ElectraPretrainer
(
tf
.
keras
.
Model
):
"""ELECTRA network training model.
This is an implementation of the network structure described in "ELECTRA:
Pre-training Text Encoders as Discriminators Rather Than Generators" (
https://arxiv.org/abs/2003.10555).
The ElectraPretrainer allows a user to pass in two transformer models, one for
generator, the other for discriminator, and instantiates the masked language
model (at generator side) and classification networks (at discriminator side)
that are used to create the training objectives.
Arguments:
generator_network: A transformer network for generator, this network should
output a sequence output and an optional classification output.
discriminator_network: A transformer network for discriminator, this network
should output a sequence output
vocab_size: Size of generator output vocabulary
num_classes: Number of classes to predict from the classification network
for the generator network (not used now)
sequence_length: Input sequence length
last_hidden_dim: Last hidden dim of generator transformer output
num_token_predictions: Number of tokens to predict from the masked LM.
mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either 'logits' or
'predictions'.
disallow_correct: Whether to disallow the generator to generate the exact
same token in the original sentence
"""
def
__init__
(
self
,
generator_network
,
discriminator_network
,
vocab_size
,
num_classes
,
sequence_length
,
last_hidden_dim
,
num_token_predictions
,
mlm_activation
=
None
,
mlm_initializer
=
'glorot_uniform'
,
output_type
=
'logits'
,
disallow_correct
=
False
,
**
kwargs
):
super
(
ElectraPretrainer
,
self
).
__init__
()
self
.
_config
=
{
'generator_network'
:
generator_network
,
'discriminator_network'
:
discriminator_network
,
'vocab_size'
:
vocab_size
,
'num_classes'
:
num_classes
,
'sequence_length'
:
sequence_length
,
'last_hidden_dim'
:
last_hidden_dim
,
'num_token_predictions'
:
num_token_predictions
,
'mlm_activation'
:
mlm_activation
,
'mlm_initializer'
:
mlm_initializer
,
'output_type'
:
output_type
,
'disallow_correct'
:
disallow_correct
,
}
for
k
,
v
in
kwargs
.
items
():
self
.
_config
[
k
]
=
v
self
.
generator_network
=
generator_network
self
.
discriminator_network
=
discriminator_network
self
.
vocab_size
=
vocab_size
self
.
num_classes
=
num_classes
self
.
sequence_length
=
sequence_length
self
.
last_hidden_dim
=
last_hidden_dim
self
.
num_token_predictions
=
num_token_predictions
self
.
mlm_activation
=
mlm_activation
self
.
mlm_initializer
=
mlm_initializer
self
.
output_type
=
output_type
self
.
disallow_correct
=
disallow_correct
self
.
masked_lm
=
layers
.
MaskedLM
(
embedding_table
=
generator_network
.
get_embedding_table
(),
activation
=
mlm_activation
,
initializer
=
mlm_initializer
,
output
=
output_type
,
name
=
'generator_masked_lm'
)
self
.
classification
=
layers
.
ClassificationHead
(
inner_dim
=
last_hidden_dim
,
num_classes
=
num_classes
,
initializer
=
mlm_initializer
,
name
=
'generator_classification_head'
)
self
.
discriminator_head
=
tf
.
keras
.
layers
.
Dense
(
units
=
1
,
kernel_initializer
=
mlm_initializer
)
def
call
(
self
,
inputs
):
input_word_ids
=
inputs
[
'input_word_ids'
]
input_mask
=
inputs
[
'input_mask'
]
input_type_ids
=
inputs
[
'input_type_ids'
]
masked_lm_positions
=
inputs
[
'masked_lm_positions'
]
### Generator ###
sequence_output
,
cls_output
=
self
.
generator_network
(
[
input_word_ids
,
input_mask
,
input_type_ids
])
# The generator encoder network may get outputs from all layers.
if
isinstance
(
sequence_output
,
list
):
sequence_output
=
sequence_output
[
-
1
]
if
isinstance
(
cls_output
,
list
):
cls_output
=
cls_output
[
-
1
]
lm_outputs
=
self
.
masked_lm
(
sequence_output
,
masked_lm_positions
)
sentence_outputs
=
self
.
classification
(
sequence_output
)
### Sampling from generator ###
fake_data
=
self
.
_get_fake_data
(
inputs
,
lm_outputs
,
duplicate
=
True
)
### Discriminator ###
disc_input
=
fake_data
[
'inputs'
]
disc_label
=
fake_data
[
'is_fake_tokens'
]
disc_sequence_output
,
_
=
self
.
discriminator_network
([
disc_input
[
'input_word_ids'
],
disc_input
[
'input_mask'
],
disc_input
[
'input_type_ids'
]
])
# The discriminator encoder network may get outputs from all layers.
if
isinstance
(
disc_sequence_output
,
list
):
disc_sequence_output
=
disc_sequence_output
[
-
1
]
disc_logits
=
self
.
discriminator_head
(
disc_sequence_output
)
disc_logits
=
tf
.
squeeze
(
disc_logits
,
axis
=-
1
)
return
lm_outputs
,
sentence_outputs
,
disc_logits
,
disc_label
def
_get_fake_data
(
self
,
inputs
,
mlm_logits
,
duplicate
=
True
):
"""Generate corrupted data for discriminator.
Args:
inputs: A dict of all inputs, same as the input of call() function
mlm_logits: The generator's output logits
duplicate: Whether to copy the original inputs dict during modifications
Returns:
A dict of generated fake data
"""
inputs
=
unmask
(
inputs
,
duplicate
)
if
self
.
disallow_correct
:
disallow
=
tf
.
one_hot
(
inputs
[
'masked_lm_ids'
],
depth
=
self
.
vocab_size
,
dtype
=
tf
.
float32
)
else
:
disallow
=
None
sampled_tokens
=
tf
.
stop_gradient
(
sample_from_softmax
(
mlm_logits
,
disallow
=
disallow
))
sampled_tokids
=
tf
.
argmax
(
sampled_tokens
,
-
1
,
output_type
=
tf
.
int32
)
updated_input_ids
,
masked
=
scatter_update
(
inputs
[
'input_word_ids'
],
sampled_tokids
,
inputs
[
'masked_lm_positions'
])
labels
=
masked
*
(
1
-
tf
.
cast
(
tf
.
equal
(
updated_input_ids
,
inputs
[
'input_word_ids'
]),
tf
.
int32
))
updated_inputs
=
get_updated_inputs
(
inputs
,
duplicate
,
input_word_ids
=
updated_input_ids
)
return
{
'inputs'
:
updated_inputs
,
'is_fake_tokens'
:
labels
,
'sampled_tokens'
:
sampled_tokens
}
def
get_config
(
self
):
return
self
.
_config
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
def
scatter_update
(
sequence
,
updates
,
positions
):
"""Scatter-update a sequence.
Args:
sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
updates: A tensor of size batch_size*seq_len(*depth)
positions: A [batch_size, n_positions] tensor
Returns:
updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
tensor of "sequence" with elements at "positions" replaced by the values
at "updates". Updates to index 0 are ignored. If there are duplicated
positions the update is only applied once.
updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
updated.
"""
shape
=
tf_utils
.
get_shape_list
(
sequence
,
expected_rank
=
[
2
,
3
])
depth_dimension
=
(
len
(
shape
)
==
3
)
if
depth_dimension
:
batch_size
,
seq_len
,
depth
=
shape
else
:
batch_size
,
seq_len
=
shape
depth
=
1
sequence
=
tf
.
expand_dims
(
sequence
,
-
1
)
n_positions
=
tf_utils
.
get_shape_list
(
positions
)[
1
]
shift
=
tf
.
expand_dims
(
seq_len
*
tf
.
range
(
batch_size
),
-
1
)
flat_positions
=
tf
.
reshape
(
positions
+
shift
,
[
-
1
,
1
])
flat_updates
=
tf
.
reshape
(
updates
,
[
-
1
,
depth
])
updates
=
tf
.
scatter_nd
(
flat_positions
,
flat_updates
,
[
batch_size
*
seq_len
,
depth
])
updates
=
tf
.
reshape
(
updates
,
[
batch_size
,
seq_len
,
depth
])
flat_updates_mask
=
tf
.
ones
([
batch_size
*
n_positions
],
tf
.
int32
)
updates_mask
=
tf
.
scatter_nd
(
flat_positions
,
flat_updates_mask
,
[
batch_size
*
seq_len
])
updates_mask
=
tf
.
reshape
(
updates_mask
,
[
batch_size
,
seq_len
])
not_first_token
=
tf
.
concat
([
tf
.
zeros
((
batch_size
,
1
),
tf
.
int32
),
tf
.
ones
((
batch_size
,
seq_len
-
1
),
tf
.
int32
)
],
-
1
)
updates_mask
*=
not_first_token
updates_mask_3d
=
tf
.
expand_dims
(
updates_mask
,
-
1
)
# account for duplicate positions
if
sequence
.
dtype
==
tf
.
float32
:
updates_mask_3d
=
tf
.
cast
(
updates_mask_3d
,
tf
.
float32
)
updates
/=
tf
.
maximum
(
1.0
,
updates_mask_3d
)
else
:
assert
sequence
.
dtype
==
tf
.
int32
updates
=
tf
.
math
.
floordiv
(
updates
,
tf
.
maximum
(
1
,
updates_mask_3d
))
updates_mask
=
tf
.
minimum
(
updates_mask
,
1
)
updates_mask_3d
=
tf
.
minimum
(
updates_mask_3d
,
1
)
updated_sequence
=
(((
1
-
updates_mask_3d
)
*
sequence
)
+
(
updates_mask_3d
*
updates
))
if
not
depth_dimension
:
updated_sequence
=
tf
.
squeeze
(
updated_sequence
,
-
1
)
return
updated_sequence
,
updates_mask
def
sample_from_softmax
(
logits
,
disallow
=
None
):
"""Implement softmax sampling using gumbel softmax trick.
Args:
logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
the generator output logits for each masked position.
disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size [batch_size, num_token_predictions, vocab_size]
indicating the true word id in each masked position.
Returns:
sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
tensor indicating the sampled word id in each masked position.
"""
if
disallow
is
not
None
:
logits
-=
1000.0
*
disallow
uniform_noise
=
tf
.
random
.
uniform
(
tf_utils
.
get_shape_list
(
logits
),
minval
=
0
,
maxval
=
1
)
gumbel_noise
=
-
tf
.
math
.
log
(
-
tf
.
math
.
log
(
uniform_noise
+
1e-9
)
+
1e-9
)
# Here we essentially follow the original paper and use temperature 1.0 for
# generator output logits.
sampled_tokens
=
tf
.
one_hot
(
tf
.
argmax
(
tf
.
nn
.
softmax
(
logits
+
gumbel_noise
),
-
1
,
output_type
=
tf
.
int32
),
logits
.
shape
[
-
1
])
return
sampled_tokens
def
unmask
(
inputs
,
duplicate
):
unmasked_input_word_ids
,
_
=
scatter_update
(
inputs
[
'input_word_ids'
],
inputs
[
'masked_lm_ids'
],
inputs
[
'masked_lm_positions'
])
return
get_updated_inputs
(
inputs
,
duplicate
,
input_word_ids
=
unmasked_input_word_ids
)
def
get_updated_inputs
(
inputs
,
duplicate
,
**
kwargs
):
if
duplicate
:
new_inputs
=
copy
.
copy
(
inputs
)
else
:
new_inputs
=
inputs
for
k
,
v
in
kwargs
.
items
():
new_inputs
[
k
]
=
v
return
new_inputs
official/nlp/modeling/models/electra_pretrainer_test.py
0 → 100644
View file @
b708fd68
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.modeling
import
networks
from
official.nlp.modeling.models
import
electra_pretrainer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@
keras_parameterized
.
run_all_keras_modes
class
ElectraPretrainerTest
(
keras_parameterized
.
TestCase
):
def
test_electra_pretrainer
(
self
):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the ELECTRA trainer.
vocab_size
=
100
sequence_length
=
512
test_generator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
sequence_length
=
sequence_length
)
test_discriminator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
sequence_length
=
sequence_length
)
# Create a ELECTRA trainer with the created network.
num_classes
=
3
num_token_predictions
=
2
eletrca_trainer_model
=
electra_pretrainer
.
ElectraPretrainer
(
generator_network
=
test_generator_network
,
discriminator_network
=
test_discriminator_network
,
vocab_size
=
vocab_size
,
num_classes
=
num_classes
,
sequence_length
=
sequence_length
,
last_hidden_dim
=
768
,
num_token_predictions
=
num_token_predictions
,
disallow_correct
=
True
)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
mask
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
type_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
lm_positions
=
tf
.
keras
.
Input
(
shape
=
(
num_token_predictions
,),
dtype
=
tf
.
int32
)
lm_ids
=
tf
.
keras
.
Input
(
shape
=
(
num_token_predictions
,),
dtype
=
tf
.
int32
)
inputs
=
{
'input_word_ids'
:
word_ids
,
'input_mask'
:
mask
,
'input_type_ids'
:
type_ids
,
'masked_lm_positions'
:
lm_positions
,
'masked_lm_ids'
:
lm_ids
}
# Invoke the trainer model on the inputs. This causes the layer to be built.
lm_outs
,
cls_outs
,
disc_logits
,
disc_label
=
eletrca_trainer_model
(
inputs
)
# Validate that the outputs are of the expected shape.
expected_lm_shape
=
[
None
,
num_token_predictions
,
vocab_size
]
expected_classification_shape
=
[
None
,
num_classes
]
expected_disc_logits_shape
=
[
None
,
sequence_length
]
expected_disc_label_shape
=
[
None
,
sequence_length
]
self
.
assertAllEqual
(
expected_lm_shape
,
lm_outs
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_classification_shape
,
cls_outs
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_disc_logits_shape
,
disc_logits
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected_disc_label_shape
,
disc_label
.
shape
.
as_list
())
def
test_electra_trainer_tensor_call
(
self
):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.)
test_generator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_discriminator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
# Create a ELECTRA trainer with the created network.
eletrca_trainer_model
=
electra_pretrainer
.
ElectraPretrainer
(
generator_network
=
test_generator_network
,
discriminator_network
=
test_discriminator_network
,
vocab_size
=
100
,
num_classes
=
2
,
sequence_length
=
3
,
last_hidden_dim
=
768
,
num_token_predictions
=
2
)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids
=
tf
.
constant
([[
1
,
1
,
1
],
[
2
,
2
,
2
]],
dtype
=
tf
.
int32
)
mask
=
tf
.
constant
([[
1
,
1
,
1
],
[
1
,
0
,
0
]],
dtype
=
tf
.
int32
)
type_ids
=
tf
.
constant
([[
1
,
1
,
1
],
[
2
,
2
,
2
]],
dtype
=
tf
.
int32
)
lm_positions
=
tf
.
constant
([[
0
,
1
],
[
0
,
2
]],
dtype
=
tf
.
int32
)
lm_ids
=
tf
.
constant
([[
10
,
20
],
[
20
,
30
]],
dtype
=
tf
.
int32
)
inputs
=
{
'input_word_ids'
:
word_ids
,
'input_mask'
:
mask
,
'input_type_ids'
:
type_ids
,
'masked_lm_positions'
:
lm_positions
,
'masked_lm_ids'
:
lm_ids
}
# Invoke the trainer model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
_
,
_
,
_
,
_
=
eletrca_trainer_model
(
inputs
)
def
test_serialize_deserialize
(
self
):
"""Validate that the ELECTRA trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_generator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_discriminator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
# Create a ELECTRA trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
electra_trainer_model
=
electra_pretrainer
.
ElectraPretrainer
(
generator_network
=
test_generator_network
,
discriminator_network
=
test_discriminator_network
,
vocab_size
=
100
,
num_classes
=
2
,
sequence_length
=
3
,
last_hidden_dim
=
768
,
num_token_predictions
=
2
)
# Create another BERT trainer via serialization and deserialization.
config
=
electra_trainer_model
.
get_config
()
new_electra_trainer_model
=
electra_pretrainer
.
ElectraPretrainer
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_electra_trainer_model
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
electra_trainer_model
.
get_config
(),
new_electra_trainer_model
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/nlp/modeling/networks/transformer_encoder.py
View file @
b708fd68
...
@@ -60,7 +60,7 @@ class TransformerEncoder(tf.keras.Model):
...
@@ -60,7 +60,7 @@ class TransformerEncoder(tf.keras.Model):
initializer: The initialzer to use for all weights in this encoder.
initializer: The initialzer to use for all weights in this encoder.
return_all_encoder_outputs: Whether to output sequence embedding outputs of
return_all_encoder_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
all encoder transformer layers.
output_range:
t
he sequence output range, [0, output_range), by slicing the
output_range:
T
he sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yeilds the full
target sequence will attend to the source sequence, which yeilds the full
output.
output.
...
@@ -69,6 +69,10 @@ class TransformerEncoder(tf.keras.Model):
...
@@ -69,6 +69,10 @@ class TransformerEncoder(tf.keras.Model):
two matrices in the shape of ['vocab_size', 'embedding_width'] and
two matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
smaller than 'hidden_size').
embedding_layer: The word embedding layer. `None` means we will create a new
embedding layer. Otherwise, we will reuse the given embedding layer. This
parameter is originally added for ELECTRA model which needs to tie the
generator embeddings with the discriminator embeddings.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -87,6 +91,7 @@ class TransformerEncoder(tf.keras.Model):
...
@@ -87,6 +91,7 @@ class TransformerEncoder(tf.keras.Model):
return_all_encoder_outputs
=
False
,
return_all_encoder_outputs
=
False
,
output_range
=
None
,
output_range
=
None
,
embedding_width
=
None
,
embedding_width
=
None
,
embedding_layer
=
None
,
**
kwargs
):
**
kwargs
):
activation
=
tf
.
keras
.
activations
.
get
(
activation
)
activation
=
tf
.
keras
.
activations
.
get
(
activation
)
initializer
=
tf
.
keras
.
initializers
.
get
(
initializer
)
initializer
=
tf
.
keras
.
initializers
.
get
(
initializer
)
...
@@ -121,11 +126,14 @@ class TransformerEncoder(tf.keras.Model):
...
@@ -121,11 +126,14 @@ class TransformerEncoder(tf.keras.Model):
if
embedding_width
is
None
:
if
embedding_width
is
None
:
embedding_width
=
hidden_size
embedding_width
=
hidden_size
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
if
embedding_layer
is
None
:
vocab_size
=
vocab_size
,
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
embedding_width
=
embedding_width
,
vocab_size
=
vocab_size
,
initializer
=
initializer
,
embedding_width
=
embedding_width
,
name
=
'word_embeddings'
)
initializer
=
initializer
,
name
=
'word_embeddings'
)
else
:
self
.
_embedding_layer
=
embedding_layer
word_embeddings
=
self
.
_embedding_layer
(
word_ids
)
word_embeddings
=
self
.
_embedding_layer
(
word_ids
)
# Always uses dynamic slicing for simplicity.
# Always uses dynamic slicing for simplicity.
...
@@ -209,6 +217,9 @@ class TransformerEncoder(tf.keras.Model):
...
@@ -209,6 +217,9 @@ class TransformerEncoder(tf.keras.Model):
def
get_embedding_table
(
self
):
def
get_embedding_table
(
self
):
return
self
.
_embedding_layer
.
embeddings
return
self
.
_embedding_layer
.
embeddings
def
get_embedding_layer
(
self
):
return
self
.
_embedding_layer
def
get_config
(
self
):
def
get_config
(
self
):
return
self
.
_config_dict
return
self
.
_config_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment