Commit 5e5cdec3 authored Aug 30, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 393909662
parent b8a51be8

Showing 3 changed files with 62 additions and 47 deletions
official/nlp/projects/teams/teams.py (+42, -32)
official/nlp/projects/teams/teams_pretrainer.py (+16, -15)
official/nlp/projects/teams/teams_pretrainer_test.py (+4, -0)
official/nlp/projects/teams/teams.py
@@ -21,45 +21,52 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks


 @dataclasses.dataclass
 class TeamsPretrainerConfig(base_config.Config):
   """Teams pretrainer configuration."""
   num_masked_tokens: int = 76
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
   # Candidate size for multi-word selection task, including the correct word.
   candidate_size: int = 5
   # Weight for the generator masked language model task.
   generator_loss_weight: float = 1.0
   # Weight for the replaced token detection task.
   discriminator_rtd_loss_weight: float = 5.0
   # Weight for the multi-word selection task.
   discriminator_mws_loss_weight: float = 2.0
   # Whether share embedding network between generator and discriminator.
   tie_embeddings: bool = True
   # Number of bottom layers shared between generator and discriminator.
   # Non-positive value implies no sharing.
   num_shared_generator_hidden_layers: int = 3
   # Number of bottom layers shared between different discriminator tasks.
   num_discriminator_task_agnostic_layers: int = 11
   disallow_correct: bool = False
   generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
   discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
   # Used for compatibility with continuous finetuning where common BERT config
   # is used.
   encoder: encoders.EncoderConfig = encoders.EncoderConfig()


 @gin.configurable
 def get_encoder(bert_config,
-                encoder_scaffold_cls,
-                embedding_inst=None,
-                hidden_inst=None):
+                embedding_network=None,
+                hidden_layers=layers.Transformer):
   """Gets a 'EncoderScaffold' object.

   Args:
     bert_config: A 'modeling.BertConfig'.
-    encoder_scaffold_cls: An EncoderScaffold class.
-    embedding_inst: Embedding instance.
-    hidden_inst: List of hidden layer instances.
+    embedding_network: Embedding network instance.
+    hidden_layers: List of hidden layer instances.

   Returns:
     A encoder object.
   """
-  if embedding_inst is not None:
   # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
   # embedding_size is required for PackedSequenceEmbedding.
   if bert_config.embedding_size is None:
     bert_config.embedding_size = bert_config.hidden_size
   embedding_cfg = dict(
       vocab_size=bert_config.vocab_size,
       type_vocab_size=bert_config.type_vocab_size,
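For orientation, the dataclass above is an ordinary `base_config.Config`, so the new task weights can be overridden at construction time. A minimal, hypothetical sketch (field names are exactly the ones in this hunk; the values and the call itself are not part of the commit):

from official.nlp.projects.teams import teams

# Hypothetical override of the loss-weight fields defined above.
config = teams.TeamsPretrainerConfig(
    candidate_size=5,
    generator_loss_weight=1.0,
    discriminator_rtd_loss_weight=5.0,
    discriminator_mws_loss_weight=2.0,
    num_shared_generator_hidden_layers=3,
    num_discriminator_task_agnostic_layers=11,
)
# base_config.Config supports dict round-tripping via as_dict().
print(config.as_dict())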
@@ -68,27 +75,30 @@ def get_encoder(bert_config,
       max_seq_length=bert_config.max_position_embeddings,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
-      dropout_rate=bert_config.hidden_dropout_prob,
+      dropout_rate=bert_config.dropout_rate,
   )
-  embedding_inst = networks.PackedSequenceEmbedding(**embedding_cfg)
   hidden_cfg = dict(
       num_attention_heads=bert_config.num_attention_heads,
       intermediate_size=bert_config.intermediate_size,
-      intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
-      dropout_rate=bert_config.hidden_dropout_prob,
-      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+      intermediate_activation=tf_utils.get_activation(
+          bert_config.hidden_activation),
+      dropout_rate=bert_config.dropout_rate,
+      attention_dropout_rate=bert_config.attention_dropout_rate,
       kernel_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
   )
+  if embedding_network is None:
+    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
   kwargs = dict(
       embedding_cfg=embedding_cfg,
-      embedding_cls=embedding_inst,
-      hidden_cls=hidden_inst,
+      embedding_cls=embedding_network,
+      hidden_cls=hidden_layers,
       hidden_cfg=hidden_cfg,
-      num_hidden_instances=bert_config.num_hidden_layers,
+      num_hidden_instances=bert_config.num_layers,
       pooled_output_dim=bert_config.hidden_size,
       pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=bert_config.initializer_range))
+          stddev=bert_config.initializer_range),
+      dict_outputs=True)
   # Relies on gin configuration to define the Transformer encoder arguments.
-  return encoder_scaffold_cls(**kwargs)
+  return networks.encoder_scaffold.EncoderScaffold(**kwargs)
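After this change, callers no longer pass an `EncoderScaffold` class; the scaffold is built internally and the config field names switch to the `encoders.BertEncoderConfig` spelling (`num_layers`, `hidden_activation`, `dropout_rate`, `attention_dropout_rate`). A sketch of the new call shape, with hypothetical sizes that are not part of the commit:

from official.nlp.configs import encoders
from official.nlp.projects.teams import teams

# Hypothetical small config; the field names are the ones the new
# get_encoder body reads above.
bert_config = encoders.BertEncoderConfig(
    vocab_size=30522,
    hidden_size=128,
    num_layers=2,
    num_attention_heads=2,
    intermediate_size=512,
)
# embedding_network defaults to None, so get_encoder builds a
# PackedSequenceEmbedding from embedding_cfg; hidden_layers defaults
# to layers.Transformer.
encoder = teams.get_encoder(bert_config)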
official/nlp/projects/teams/teams_pretrainer.py
@@ -228,16 +228,14 @@ class TeamsPretrainer(tf.keras.Model):
     num_discriminator_task_agnostic_layers: Number of layers shared between
       multi-word selection and random token detection discriminators.
     vocab_size: Size of generator output vocabulary
-    num_classes: Number of classes to predict from the classification network
-      for the generator network (not used now)
     candidate_size: Candidate size for multi-word selection task,
       including the correct word.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
     output_type: The output style for this network. Can be either `logits` or
       `predictions`.
     disallow_correct: Whether to disallow the generator to generate the exact
       same token in the original sentence
   """

   def __init__(self,
@@ -311,12 +309,15 @@ class TeamsPretrainer(tf.keras.Model):
       outputs: A dict of pretrainer model outputs, including
         (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
         tensor indicating logits on masked positions.
-        (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
-        logits for nsp task.
-        (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
+        (2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
         logits for discriminator replaced token detection task.
-        (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
+        (3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
         target labels for discriminator replaced token detection task.
+        (4) disc_mws_logits: A `[batch_size, num_token_predictions,
+        candidate_size]` tensor indicating logits for discriminator multi-word
+        selection task.
+        (5) disc_mws_labels: A `[batch_size, num_token_predictions]` tensor
+        indicating target labels for discriminator multi-word selection task.
     """
     input_word_ids = inputs['input_word_ids']
     input_mask = inputs['input_mask']
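The docstring above now documents five output keys. A hedged sketch of consuming them (here `pretrainer` is an already-constructed TeamsPretrainer and `inputs` a dict with the keys `call()` reads; construction is elided, and shapes follow the docstring):

# Sketch only; key names come from the outputs dict in this diff.
outputs = pretrainer(inputs)
lm_logits = outputs['lm_outputs']        # [batch, num_preds, vocab_size]
rtd_logits = outputs['disc_rtd_logits']  # [batch, seq_len]
rtd_labels = outputs['disc_rtd_label']   # [batch, seq_len]
mws_logits = outputs['disc_mws_logits']  # [batch, num_preds, candidate_size]
mws_labels = outputs['disc_mws_label']   # [batch, num_preds]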
@@ -326,9 +327,6 @@ class TeamsPretrainer(tf.keras.Model):
     # Runs generator.
     sequence_output = self.generator_network(
         [input_word_ids, input_mask, input_type_ids])['sequence_output']
-    # The generator encoder network may get outputs from all layers.
-    if isinstance(sequence_output, list):
-      sequence_output = sequence_output[-1]
     lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
@@ -353,11 +351,14 @@ class TeamsPretrainer(tf.keras.Model):
     disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
                                                   masked_lm_positions,
                                                   disc_mws_candidates)
+    disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)

     outputs = {
         'lm_outputs': lm_outputs,
         'disc_rtd_logits': disc_rtd_logits,
         'disc_rtd_label': disc_rtd_label,
         'disc_mws_logits': disc_mws_logits,
+        'disc_mws_label': disc_mws_label,
     }
     return outputs
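Note what the added label line does: every multi-word-selection target is class 0. A tiny illustration (hypothetical shapes; the reading of slot 0 as the correct word is an inference from the config comment, not stated in the diff):

import tensorflow as tf

masked_lm_positions = tf.zeros((2, 20), dtype=tf.int32)  # hypothetical batch
disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)
# For each masked position, the target among `candidate_size` candidates is
# index 0, consistent with the config comment that the candidate set
# "includes the correct word" (presumably placed in slot 0).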
@@ -408,7 +409,7 @@ class TeamsPretrainer(tf.keras.Model):
   @property
   def checkpoint_items(self):
     """Returns a dictionary of items to be additionally checkpointed."""
-    items = dict(encoder=self.discriminator_network)
+    items = dict(encoder=self.discriminator_mws_network)
     return items

   def get_config(self):
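The one-line change above swaps which network is exported under the generic `encoder` key, which is what downstream fine-tuning restores. A hedged usage sketch (`pretrainer` is assumed constructed; the path is hypothetical):

import tensorflow as tf

# checkpoint_items now maps 'encoder' to the MWS discriminator network,
# so this checkpoint carries that network for downstream fine-tuning.
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.save('/tmp/teams_pretrain/ckpt')  # hypothetical path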
@@ -419,7 +420,7 @@ class TeamsPretrainer(tf.keras.Model):
     return cls(**config)


-def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
+def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   """Implement softmax sampling using gumbel softmax trick to select k items.

   Args:

@@ -429,8 +430,8 @@ def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
     disallow: If `None`, we directly sample tokens from the logits. Otherwise,
       this is a tensor of size [batch_size, num_token_predictions, vocab_size]
       indicating the true word id in each masked position.
-    use_topk: Whether to use tf.nn.top_k or using approximate iterative
-      approach which is faster.
+    use_topk: Whether to use tf.nn.top_k or using iterative approach where
+      the latter is empirically faster.

   Returns:
     sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
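The docstring names the Gumbel-softmax trick for drawing k items. A self-contained sketch of Gumbel-top-k sampling follows; it is illustrative only, not the file's exact implementation (which also offers the iterative non-top_k path), and the `disallow` handling mirrors the docstring's description of masking out the true word id:

import tensorflow as tf

def gumbel_top_k_sample(logits, k, disallow=None):
  """Samples k distinct items per row via the Gumbel-top-k trick.

  Adding independent Gumbel(0, 1) noise to the logits and taking the top-k
  indices draws k items without replacement, in proportion to
  softmax(logits). Sketch of the technique the docstring names, not the
  exact code in teams_pretrainer.py.
  """
  if disallow is not None:
    # Push disallowed entries (e.g. the original token) toward -inf.
    logits -= 10000.0 * disallow
  uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  _, sampled = tf.nn.top_k(logits + gumbel_noise, k=k)
  return sampled  # [..., k] index tensor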
official/nlp/projects/teams/teams_pretrainer_test.py
@@ -103,6 +103,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
     disc_rtd_logits = outputs['disc_rtd_logits']
     disc_rtd_label = outputs['disc_rtd_label']
     disc_mws_logits = outputs['disc_mws_logits']
+    disc_mws_label = outputs['disc_mws_label']

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]

@@ -111,6 +112,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
     expected_disc_disc_mws_logits_shape = [
         None, num_token_predictions, candidate_size
     ]
+    expected_disc_disc_mws_label_shape = [None, num_token_predictions]
     self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
     self.assertAllEqual(expected_disc_rtd_logits_shape,
                         disc_rtd_logits.shape.as_list())

@@ -118,6 +120,8 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
                         disc_rtd_label.shape.as_list())
     self.assertAllEqual(expected_disc_disc_mws_logits_shape,
                         disc_mws_logits.shape.as_list())
+    self.assertAllEqual(expected_disc_disc_mws_label_shape,
+                        disc_mws_label.shape.as_list())

   def test_teams_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
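The new assertions pin down the multi-word-selection shape contract. A standalone sketch of that contract with hypothetical sizes (names mirror the test variables; none of this code is in the commit):

import tensorflow as tf

# Hypothetical batch of 2 with 20 masked positions and 5 candidates.
num_token_predictions, candidate_size = 20, 5
disc_mws_logits = tf.zeros((2, num_token_predictions, candidate_size))
disc_mws_label = tf.zeros((2, num_token_predictions), dtype=tf.int32)
assert disc_mws_logits.shape.as_list()[1:] == [num_token_predictions,
                                               candidate_size]
assert disc_mws_label.shape.as_list()[1:] == [num_token_predictions]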