ModelZoo / ResNet50_tensorflow / Commits

Commit d544a3d8, authored Aug 27, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 393343072
Parent: c523defa
Showing 3 changed files with 740 additions and 0 deletions.
official/nlp/projects/teams/teams.py (+94, -0)
official/nlp/projects/teams/teams_pretrainer.py (+462, -0)
official/nlp/projects/teams/teams_pretrainer_test.py (+184, -0)
official/nlp/projects/teams/teams.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TEAMS model configurations and instantiation methods."""
import dataclasses

import gin
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import networks


@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
  """Teams pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  # Whether to share the embedding network between generator and discriminator.
  tie_embeddings: bool = True
  # Number of bottom layers shared between generator and discriminator.
  num_shared_generator_hidden_layers: int = 3
  # Number of bottom layers shared between different discriminator tasks.
  num_discriminator_task_agnostic_layers: int = 11
  disallow_correct: bool = False
  generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
  discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()


@gin.configurable
def get_encoder(bert_config,
                encoder_scaffold_cls,
                embedding_inst=None,
                hidden_inst=None):
"""Gets a 'EncoderScaffold' object.
Args:
bert_config: A 'modeling.BertConfig'.
encoder_scaffold_cls: An EncoderScaffold class.
embedding_inst: Embedding instance.
hidden_inst: List of hidden layer instances.
Returns:
A encoder object.
"""
  # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
  embedding_cfg = dict(
      vocab_size=bert_config.vocab_size,
      type_vocab_size=bert_config.type_vocab_size,
      hidden_size=bert_config.hidden_size,
      embedding_width=bert_config.embedding_size,
      max_seq_length=bert_config.max_position_embeddings,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      dropout_rate=bert_config.hidden_dropout_prob,
  )
  if embedding_inst is None:
    # Builds a default packed-sequence embedding network when none is provided.
    embedding_inst = networks.PackedSequenceEmbedding(**embedding_cfg)
  hidden_cfg = dict(
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      kernel_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
  )
  kwargs = dict(
      embedding_cfg=embedding_cfg,
      embedding_cls=embedding_inst,
      hidden_cls=hidden_inst,
      hidden_cfg=hidden_cfg,
      num_hidden_instances=bert_config.num_hidden_layers,
      pooled_output_dim=bert_config.hidden_size,
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  # Relies on gin configuration to define the Transformer encoder arguments.
  return encoder_scaffold_cls(**kwargs)
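The configuration class above is a plain `base_config.Config` dataclass, so individual fields can be overridden at construction time. A minimal sketch, assuming the Model Garden `official` package is importable; the layer counts and hidden sizes are illustrative only:

from official.nlp.configs import encoders
from official.nlp.projects.teams import teams

# Illustrative only: these values are not recommended hyperparameters.
pretrainer_cfg = teams.TeamsPretrainerConfig(
    num_masked_tokens=76,
    discriminator_loss_weight=50.0,
    generator=encoders.BertEncoderConfig(num_layers=4, hidden_size=256),
    discriminator=encoders.BertEncoderConfig(num_layers=12, hidden_size=256))
# Config subclasses serialize to plain dicts, which is convenient for logging.
print(pretrainer_cfg.as_dict())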
official/nlp/projects/teams/teams_pretrainer.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
  """Replaced token detection discriminator head.

  Arguments:
    encoder_cfg: Encoder config, used to create hidden layers and head.
    num_task_agnostic_layers: Number of task agnostic layers in the
      discriminator.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """

  def __init__(self,
               encoder_cfg,
               num_task_agnostic_layers,
               output='logits',
               name='rtd',
               **kwargs):
    super(ReplacedTokenDetectionHead, self).__init__(name=name, **kwargs)
    self.num_task_agnostic_layers = num_task_agnostic_layers
    self.hidden_size = encoder_cfg['embedding_cfg']['hidden_size']
    self.num_hidden_instances = encoder_cfg['num_hidden_instances']
    self.hidden_cfg = encoder_cfg['hidden_cfg']
    self.activation = self.hidden_cfg['intermediate_activation']
    self.initializer = self.hidden_cfg['kernel_initializer']

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output
  def build(self, input_shape):
    self.hidden_layers = []
    for i in range(self.num_task_agnostic_layers, self.num_hidden_instances):
      self.hidden_layers.append(
          layers.Transformer(
              num_attention_heads=self.hidden_cfg['num_attention_heads'],
              intermediate_size=self.hidden_cfg['intermediate_size'],
              intermediate_activation=self.activation,
              dropout_rate=self.hidden_cfg['dropout_rate'],
              attention_dropout_rate=self.hidden_cfg['attention_dropout_rate'],
              kernel_initializer=self.initializer,
              name='transformer/layer_%d_rtd' % i))
    self.dense = tf.keras.layers.Dense(
        self.hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/rtd_dense')
    self.rtd_head = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=self.initializer,
        name='transform/rtd_head')
  def call(self, sequence_data, input_mask):
    """Computes replaced token detection logits for each sequence position.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor.
      input_mask: A [batch_size, seq_length] binary mask to separate the input
        from the padding.

    Returns:
      A [batch_size, seq_length] tensor.
    """
    attention_mask = layers.SelfAttentionMask()([sequence_data, input_mask])
    data = sequence_data
    # Chains the task-specific transformer layers on top of the shared output.
    for hidden_layer in self.hidden_layers:
      data = hidden_layer([data, attention_mask])
    rtd_logits = self.rtd_head(self.dense(data))
    return tf.squeeze(rtd_logits, axis=-1)
class MultiWordSelectionHead(tf.keras.layers.Layer):
  """Multi-word selection discriminator head.

  Arguments:
    embedding_table: The embedding table.
    activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
  """
  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               name='mws',
               **kwargs):
    super(MultiWordSelectionHead, self).__init__(name=name, **kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def build(self, input_shape):
    self._vocab_size, self.embed_size = self.embedding_table.shape
    self.dense = tf.keras.layers.Dense(
        self.embed_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/mws_dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/mws_layernorm')
    super(MultiWordSelectionHead, self).build(input_shape)
  def call(self, sequence_data, masked_positions, candidate_sets):
    """Compute inner-products of hidden vectors with sampled element embeddings.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor.
      masked_positions: A [batch_size, num_prediction] tensor.
      candidate_sets: A [batch_size, num_prediction, k] tensor.

    Returns:
      A [batch_size, num_prediction, k] tensor.
    """
    # Gets shapes for later usage
    candidate_set_shape = tf_utils.get_shape_list(candidate_sets)
    num_prediction = candidate_set_shape[1]

    # Gathers hidden vectors -> (batch_size, num_prediction, 1, embed_size)
    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
    lm_data = self.dense(masked_lm_input)
    lm_data = self.layer_norm(lm_data)
    lm_data = tf.expand_dims(
        tf.reshape(lm_data, [-1, num_prediction, self.embed_size]), 2)

    # Gathers embeddings -> (batch_size, num_prediction, embed_size, k)
    flat_candidate_sets = tf.reshape(candidate_sets, [-1])
    candidate_embeddings = tf.gather(self.embedding_table, flat_candidate_sets)
    candidate_embeddings = tf.reshape(
        candidate_embeddings,
        tf.concat([tf.shape(candidate_sets), [self.embed_size]], axis=0))
    candidate_embeddings.set_shape(
        candidate_sets.shape.as_list() + [self.embed_size])
    candidate_embeddings = tf.transpose(candidate_embeddings, [0, 1, 3, 2])

    # matrix multiplication + squeeze -> (batch_size, num_prediction, k)
    logits = tf.matmul(lm_data, candidate_embeddings)
    logits = tf.squeeze(logits, 2)
    if self._output_type == 'logits':
      return logits
    return tf.nn.log_softmax(logits)
  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of shape
        (`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is
        number of hidden units.
      positions: Positions ids of tokens in batched sequences.

    Returns:
      Sequence tensor of shape (batch_size * num_predictions, num_hidden).
    """
    sequence_shape = tf_utils.get_shape_list(
        sequence_tensor, name='sequence_output_tensor')
    batch_size, seq_length, width = sequence_shape

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor
@tf.keras.utils.register_keras_serializable(package='Text')
class TeamsPretrainer(tf.keras.Model):
  """TEAMS network training model.

  This is an implementation of the network structure described in "Training
  ELECTRA Augmented with Multi-word Selection"
  (https://arxiv.org/abs/2106.00139).

  The TeamsPretrainer allows a user to pass in two transformer encoders, one
  for the generator and the other for the discriminator (multi-word selection).
  The pretrainer then instantiates the masked language model (on the generator
  side) and the classification networks (both the multi-word selection head
  and the replaced token detection head) that are used to create the training
  objectives.

  *Note* that the model is constructed by Keras Subclass API, where layers are
  defined inside `__init__` and `call()` implements the computation.

  Args:
    generator_network: A transformer encoder for the generator; this network
      should output a sequence output.
    discriminator_mws_network: A transformer encoder for the multi-word
      selection discriminator; this network should output a sequence output.
    num_discriminator_task_agnostic_layers: Number of layers shared between
      the multi-word selection and replaced token detection discriminators.
    vocab_size: Size of generator output vocabulary.
    candidate_size: Number of candidates (the original token plus sampled
      negatives) scored by the multi-word selection head for each masked
      position.
    mlm_activation: The activation (if any) to use in the masked LM and
      classification networks. If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output_type: The output style for this network. Can be either `logits` or
      `predictions`.
  """
  def __init__(self,
               generator_network,
               discriminator_mws_network,
               num_discriminator_task_agnostic_layers,
               vocab_size,
               candidate_size=5,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
               output_type='logits',
               **kwargs):
    super().__init__()
    self._config = {
        'generator_network': generator_network,
        'discriminator_mws_network': discriminator_mws_network,
        'num_discriminator_task_agnostic_layers':
            num_discriminator_task_agnostic_layers,
        'vocab_size': vocab_size,
        'candidate_size': candidate_size,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
        'output_type': output_type,
    }
    for k, v in kwargs.items():
      self._config[k] = v

    self.generator_network = generator_network
    self.discriminator_mws_network = discriminator_mws_network
    self.vocab_size = vocab_size
    self.candidate_size = candidate_size
    self.mlm_activation = mlm_activation
    self.mlm_initializer = mlm_initializer
    self.output_type = output_type
    embedding_table = generator_network.embedding_network.get_embedding_table()
    self.masked_lm = layers.MaskedLM(
        embedding_table=embedding_table,
        activation=mlm_activation,
        initializer=mlm_initializer,
        output=output_type,
        name='generator_masked_lm')
    discriminator_cfg = self.discriminator_mws_network.get_config()
    self.discriminator_rtd_head = ReplacedTokenDetectionHead(
        encoder_cfg=discriminator_cfg,
        num_task_agnostic_layers=num_discriminator_task_agnostic_layers,
        output=output_type,
        name='discriminator_rtd')
    hidden_cfg = discriminator_cfg['hidden_cfg']
    self.discriminator_mws_head = MultiWordSelectionHead(
        embedding_table=embedding_table,
        activation=hidden_cfg['intermediate_activation'],
        initializer=hidden_cfg['kernel_initializer'],
        output=output_type,
        name='discriminator_mws')
    self.num_task_agnostic_layers = num_discriminator_task_agnostic_layers
  def call(self, inputs):
    """TEAMS forward pass.

    Args:
      inputs: A dict of all inputs, same as the standard BERT model.

    Returns:
      outputs: A dict of pretrainer model outputs, including
        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
        tensor indicating logits on masked positions.
        (2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor
        indicating logits for the replaced token detection task.
        (3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
        target labels for the replaced token detection task.
        (4) disc_mws_logits: A `[batch_size, num_token_predictions,
        candidate_size]` tensor indicating logits for the multi-word selection
        task.
    """
    input_word_ids = inputs['input_word_ids']
    input_mask = inputs['input_mask']
    input_type_ids = inputs['input_type_ids']
    masked_lm_positions = inputs['masked_lm_positions']

    # Runs generator.
    sequence_output = self.generator_network(
        [input_word_ids, input_mask, input_type_ids])['sequence_output']
    # The generator encoder network may get outputs from all layers.
    if isinstance(sequence_output, list):
      sequence_output = sequence_output[-1]

    lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)

    # Samples tokens from generator.
    fake_data = self._get_fake_data(inputs, lm_outputs)

    # Runs discriminator.
    disc_input = fake_data['inputs']
    disc_rtd_label = fake_data['is_fake_tokens']
    disc_mws_candidates = fake_data['candidate_set']
    mws_sequence_outputs = self.discriminator_mws_network([
        disc_input['input_word_ids'], disc_input['input_mask'],
        disc_input['input_type_ids']
    ])['encoder_outputs']

    # Applies replaced token detection with input selected based on
    # self.num_discriminator_task_agnostic_layers
    disc_rtd_logits = self.discriminator_rtd_head(
        mws_sequence_outputs[self.num_task_agnostic_layers - 1], input_mask)

    # Applies multi-word selection.
    disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
                                                  masked_lm_positions,
                                                  disc_mws_candidates)

    outputs = {
        'lm_outputs': lm_outputs,
        'disc_rtd_logits': disc_rtd_logits,
        'disc_rtd_label': disc_rtd_label,
        'disc_mws_logits': disc_mws_logits,
    }

    return outputs
  def _get_fake_data(self, inputs, mlm_logits):
    """Generates corrupted data for the discriminator.

    Note it is possible for a sampled token to be the same as the correct one.

    Args:
      inputs: A dict of all inputs, same as the input of the `call()` function.
      mlm_logits: The generator's output logits.

    Returns:
      A dict of generated fake data.
    """
    inputs = models.electra_pretrainer.unmask(inputs, duplicate=True)

    # Samples replaced token.
    sampled_tokens = tf.stop_gradient(
        models.electra_pretrainer.sample_from_softmax(
            mlm_logits, disallow=None))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)

    # Prepares input and label for replaced token detection task.
    updated_input_ids, masked = models.electra_pretrainer.scatter_update(
        inputs['input_word_ids'], sampled_tokids,
        inputs['masked_lm_positions'])
    rtd_labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
    updated_inputs = models.electra_pretrainer.get_updated_inputs(
        inputs, duplicate=True, input_word_ids=updated_input_ids)

    # Samples (candidate_size-1) negatives and concat with true tokens
    disallow = tf.one_hot(
        inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
    sampled_candidates = tf.stop_gradient(
        sample_k_from_softmax(
            mlm_logits, k=self.candidate_size - 1, disallow=disallow))
    true_token_id = tf.expand_dims(inputs['masked_lm_ids'], -1)
    candidate_set = tf.concat([true_token_id, sampled_candidates], -1)

    return {
        'inputs': updated_inputs,
        'is_fake_tokens': rtd_labels,
        'sampled_tokens': sampled_tokens,
        'candidate_set': candidate_set
    }
  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.discriminator_mws_network)
    return items
  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
  """Implements softmax sampling with the Gumbel softmax trick to select k items.

  Args:
    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
      the generator output logits for each masked position.
    k: Number of samples.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
      indicating the true word id in each masked position.
    use_topk: Whether to use tf.nn.top_k or the approximate iterative
      approach, which is faster.

  Returns:
    sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
      the sampled word id in each masked position.
  """
  if use_topk:
    if disallow is not None:
      logits -= 10000.0 * disallow
    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k, sorted=False)
  else:
    sampled_tokens_list = []
    vocab_size = tf_utils.get_shape_list(logits)[-1]
    if disallow is not None:
      logits -= 10000.0 * disallow
    uniform_noise = tf.random.uniform(
        tf_utils.get_shape_list(logits), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
    logits += gumbel_noise
    for _ in range(k):
      token_ids = tf.argmax(logits, -1, output_type=tf.int32)
      sampled_tokens_list.append(token_ids)
      logits -= 10000.0 * tf.one_hot(
          token_ids, depth=vocab_size, dtype=tf.float32)
    sampled_tokens = tf.stack(sampled_tokens_list, -1)
  return sampled_tokens
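The Gumbel-max trick that `sample_k_from_softmax` relies on can be illustrated in isolation. A minimal self-contained sketch with toy shapes: perturbing the logits with Gumbel noise and taking the top-k indices draws k distinct token ids, where the first draw follows the softmax of the logits and later draws behave like sampling without replacement.

import tensorflow as tf

# Toy logits over a 4-token vocabulary for a single position.
logits = tf.math.log(tf.constant([[0.05, 0.7, 0.05, 0.2]]))
uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
# Top-k over the noise-perturbed logits yields k distinct samples.
_, sampled_ids = tf.nn.top_k(logits + gumbel_noise, k=2, sorted=False)
print(sampled_ids.numpy())  # e.g. [[1 3]]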
official/nlp/projects/teams/teams_pretrainer_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TEAMS pre trainer network."""
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling.networks import encoder_scaffold
from official.nlp.modeling.networks import packed_sequence_embedding
from official.nlp.projects.teams import teams_pretrainer


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TeamsPretrainerTest(keras_parameterized.TestCase):

  # Build a transformer network to use within the TEAMS trainer.
  def _get_network(self, vocab_size):
    sequence_length = 512
    hidden_size = 50
    embedding_cfg = {
        'vocab_size': vocab_size,
        'type_vocab_size': 1,
        'hidden_size': hidden_size,
        'embedding_width': hidden_size,
        'max_seq_length': sequence_length,
        'initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
        'dropout_rate': 0.1,
    }
    embedding_inst = packed_sequence_embedding.PackedSequenceEmbedding(
        **embedding_cfg)
    hidden_cfg = {
        'num_attention_heads': 2,
        'intermediate_size': 3072,
        'intermediate_activation': activations.gelu,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'kernel_initializer': tf.keras.initializers.TruncatedNormal(
            stddev=0.02),
    }
    return encoder_scaffold.EncoderScaffold(
        num_hidden_instances=2,
        pooled_output_dim=hidden_size,
        embedding_cfg=embedding_cfg,
        embedding_cls=embedding_inst,
        hidden_cfg=hidden_cfg,
        dict_outputs=True)
  def test_teams_pretrainer(self):
    """Validate that the Keras object can be created."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network.
    candidate_size = 3
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=1,
        vocab_size=vocab_size,
        candidate_size=candidate_size)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    num_token_predictions = 2
    sequence_length = 128
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    lm_positions = tf.keras.Input(
        shape=(num_token_predictions,), dtype=tf.int32)
    lm_ids = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
    inputs = {
        'input_word_ids': word_ids,
        'input_mask': mask,
        'input_type_ids': type_ids,
        'masked_lm_positions': lm_positions,
        'masked_lm_ids': lm_ids
    }

    # Invoke the trainer model on the inputs. This causes the layer to be built.
    outputs = teams_trainer_model(inputs)
    lm_outs = outputs['lm_outputs']
    disc_rtd_logits = outputs['disc_rtd_logits']
    disc_rtd_label = outputs['disc_rtd_label']
    disc_mws_logits = outputs['disc_mws_logits']

    # Validate that the outputs are of the expected shape.
    expected_lm_shape = [None, num_token_predictions, vocab_size]
    expected_disc_rtd_logits_shape = [None, sequence_length]
    expected_disc_rtd_label_shape = [None, sequence_length]
    expected_disc_disc_mws_logits_shape = [
        None, num_token_predictions, candidate_size
    ]
    self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
    self.assertAllEqual(expected_disc_rtd_logits_shape,
                        disc_rtd_logits.shape.as_list())
    self.assertAllEqual(expected_disc_rtd_label_shape,
                        disc_rtd_label.shape.as_list())
    self.assertAllEqual(expected_disc_disc_mws_logits_shape,
                        disc_mws_logits.shape.as_list())
  def test_teams_trainer_tensor_call(self):
    """Validate that the Keras object can be invoked."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network.
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    # Create a set of 2-dimensional data tensors to feed into the model.
    word_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1, 1], [1, 0, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
    lm_positions = tf.constant([[0, 1], [0, 2]], dtype=tf.int32)
    lm_ids = tf.constant([[10, 20], [20, 30]], dtype=tf.int32)
    inputs = {
        'input_word_ids': word_ids,
        'input_mask': mask,
        'input_type_ids': type_ids,
        'masked_lm_positions': lm_positions,
        'masked_lm_ids': lm_ids
    }

    # Invoke the trainer model on the tensors. In Eager mode, this does the
    # actual calculation. (We can't validate the outputs, since the network is
    # too complex: this simply ensures we're not hitting runtime errors.)
    _ = teams_trainer_model(inputs)
  def test_serialize_deserialize(self):
    """Validate that the TEAMS trainer can be serialized and deserialized."""
    vocab_size = 100
    test_generator_network = self._get_network(vocab_size)
    test_discriminator_network = self._get_network(vocab_size)

    # Create a TEAMS trainer with the created network. (Note that all the args
    # are different, so we can catch any serialization mismatches.)
    teams_trainer_model = teams_pretrainer.TeamsPretrainer(
        generator_network=test_generator_network,
        discriminator_mws_network=test_discriminator_network,
        num_discriminator_task_agnostic_layers=2,
        vocab_size=vocab_size,
        candidate_size=2)

    # Create another TEAMS trainer via serialization and deserialization.
    config = teams_trainer_model.get_config()
    new_teams_trainer_model = teams_pretrainer.TeamsPretrainer.from_config(
        config)

    # Validate that the config can be forced to JSON.
    _ = new_teams_trainer_model.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(teams_trainer_model.get_config(),
                        new_teams_trainer_model.get_config())
if __name__ == '__main__':
  tf.test.main()
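Since the test module ends by calling tf.test.main(), it can be run directly as a module; assuming the Model Garden repository root is on PYTHONPATH, something like:

python -m official.nlp.projects.teams.teams_pretrainer_test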