ModelZoo / ResNet50_tensorflow / Commits

Commit 8754fa31 authored Jul 14, 2020 by A. Unique TensorFlower

    Internal change

    PiperOrigin-RevId: 321238481

parent b860406a
Showing 6 changed files with 306 additions and 19 deletions:

    official/nlp/configs/electra.py                           +14  -5
    official/nlp/configs/encoders.py                          +10  -5
    official/nlp/modeling/models/electra_pretrainer.py        +14  -6
    official/nlp/modeling/models/electra_pretrainer_test.py    +0  -3
    official/nlp/tasks/electra_task.py (new)                 +209  -0
    official/nlp/tasks/electra_task_test.py (new)             +59  -0
official/nlp/configs/electra.py  (+14, -5)

@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (

@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      # Copy discriminator's embeddings to generator for easier model
+      # serialization.
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)
   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
-      classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+      classification_heads=instantiate_classification_heads_from_cfgs(
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
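Taken together, the two hunks above make embedding tying the default: unless tie_embeddings is set to False, the generator is now built on top of the discriminator's embedding table. A minimal usage sketch, with encoder sizes borrowed from the new test file later in this commit (the import paths follow the file headers above):

from official.nlp.configs import electra
from official.nlp.configs import encoders

config = electra.ELECTRAPretrainerConfig(
    num_masked_tokens=20,
    sequence_length=128,
    generator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1),
    discriminator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1))
# tie_embeddings defaults to True: the generator reuses
# discriminator_network.get_embedding_layer(). Pass tie_embeddings=False
# to give each network its own embedding table.
pretrainer = electra.instantiate_pretrainer_from_cfg(config)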
official/nlp/configs/encoders.py  (+10, -5)

@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 import dataclasses
+import gin
 import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks

@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None


+@gin.configurable
 def instantiate_encoder_from_cfg(
     config: TransformerEncoderConfig,
-    encoder_cls=networks.TransformerEncoder):
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(

@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
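The new embedding_layer parameter is the mechanism behind that tying: handing an existing layer instance to a second encoder means both networks read from, and backpropagate into, a single weight table. A hedged sketch (get_embedding_layer comes from the electra.py hunk above; the rest reuses names from this file):

disc_cfg = encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=12)
gen_cfg = encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=3)

discriminator = encoders.instantiate_encoder_from_cfg(disc_cfg)
# Reusing the discriminator's OnDeviceEmbedding instance instead of
# letting the generator construct its own ties the two vocabularies.
generator = encoders.instantiate_encoder_from_cfg(
    gen_cfg, embedding_layer=discriminator.get_embedding_layer())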
official/nlp/modeling/models/electra_pretrainer.py  (+14, -6)

@@ -48,7 +48,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.

@@ -66,7 +65,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',

@@ -80,7 +78,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,

@@ -95,7 +92,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer

@@ -108,10 +104,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)

@@ -165,7 +166,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)
     outputs = {

@@ -214,6 +216,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
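Two notes on this file. The inserted discriminator_projection matches the two-layer discriminator head described in the ELECTRA paper (an activated Dense layer feeding a one-unit Dense), replacing the previous single linear layer; it also lets both heads read their widths from the encoders' _config_dict instead of the removed last_hidden_dim argument. And the new checkpoint_items property exposes the discriminator under the standard 'encoder' key, so fine-tuning tasks can warm-start from an ELECTRA checkpoint the same way they would from a BERT one. A minimal sketch of that hand-off (pretrainer stands for a built ElectraPretrainer instance):

# Saves the discriminator encoder under the 'encoder' slot that
# downstream fine-tuning tasks expect to restore from.
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.save('/tmp/electra_pretrain/ckpt')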
official/nlp/modeling/models/electra_pretrainer_test.py  (+0, -3)

@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)

@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create a set of 2-dimensional data tensors to feed into the model.

@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create another BERT trainer via serialization and deserialization.
official/nlp/tasks/electra_task.py  (new file, mode 100644, +209)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.data import pretrain_dataloader


@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
  """The model config."""
  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
  """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

  def build_model(self):
    return electra.instantiate_pretrainer_from_cfg(self.task_config.model)

  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    metrics = dict([(metric.name, metric) for metric in metrics])
    # generator lm and (optional) nsp loss.
    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        labels['masked_lm_ids'],
        tf.cast(model_outputs['lm_outputs'], tf.float32),
        from_logits=True)
    lm_label_weights = labels['masked_lm_weights']
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
    metrics['lm_example_loss'].update_state(mlm_loss)
    if 'next_sentence_labels' in labels:
      sentence_labels = labels['next_sentence_labels']
      sentence_outputs = tf.cast(
          model_outputs['sentence_outputs'], dtype=tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_outputs, from_logits=True)
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
    else:
      total_loss = mlm_loss

    # discriminator replaced token detection (rtd) loss.
    rtd_logits = model_outputs['disc_logits']
    rtd_labels = tf.cast(model_outputs['disc_label'], tf.float32)
    input_mask = tf.cast(labels['input_mask'], tf.float32)
    rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=rtd_logits, labels=rtd_labels)
    rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
    rtd_denominator = tf.reduce_sum(input_mask)
    rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
    metrics['discriminator_loss'].update_state(rtd_loss)
    total_loss = total_loss + \
        self.task_config.model.discriminator_loss_weight * rtd_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    metrics['total_loss'].update_state(total_loss)
    return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq),
                            dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='lm_example_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='discriminator_accuracy'),
    ]
    if self.task_config.train_data.use_next_sentence_label:
      metrics.append(
          tf.keras.metrics.SparseCategoricalAccuracy(
              name='next_sentence_accuracy'))
      metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
    metrics.append(tf.keras.metrics.Mean(name='discriminator_loss'))
    metrics.append(tf.keras.metrics.Mean(name='total_loss'))
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    metrics = dict([(metric.name, metric) for metric in metrics])
    if 'masked_lm_accuracy' in metrics:
      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                 model_outputs['lm_outputs'],
                                                 labels['masked_lm_weights'])
    if 'next_sentence_accuracy' in metrics:
      metrics['next_sentence_accuracy'].update_state(
          labels['next_sentence_labels'], model_outputs['sentence_outputs'])
    if 'discriminator_accuracy' in metrics:
      disc_logits_expanded = tf.expand_dims(model_outputs['disc_logits'], -1)
      discrim_full_logits = tf.concat(
          [-1.0 * disc_logits_expanded, disc_logits_expanded], -1)
      metrics['discriminator_accuracy'].update_state(
          model_outputs['disc_label'], discrim_full_logits,
          labels['input_mask'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      # TODO(b/154564893): enable loss scaling.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
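Two implementation details in the new task are easy to miss. First, both loss terms are masked means: tf.math.divide_no_nan divides the weighted sum of per-position losses by the total mask weight, so padded positions contribute nothing and an all-zero mask yields 0 rather than NaN; the RTD term is then folded in at the configured weight, i.e. total_loss = mlm_loss (+ sentence_loss) + discriminator_loss_weight * rtd_loss, where the weight defaults to 50.0 per the electra.py config above. A toy check of the masked mean, with invented numbers:

import tensorflow as tf

per_token_loss = tf.constant([2.0, 4.0, 9.0])  # hypothetical per-token losses
weights = tf.constant([1.0, 1.0, 0.0])         # third position is padding
masked_mean = tf.math.divide_no_nan(
    tf.reduce_sum(per_token_loss * weights), tf.reduce_sum(weights))
print(masked_mean.numpy())  # 3.0, the mean over the two unmasked positions

Second, process_metrics adapts SparseCategoricalAccuracy, a multi-class metric, to the discriminator's single binary logit by stacking [-logit, logit] into a two-class logit pair: the argmax of that pair is 1 exactly when the logit is positive, the same decision as thresholding the sigmoid at 0.5 (softmax over [-l, l] gives sigmoid(2l) on the positive class):

logits = tf.constant([[0.7], [-1.2]])  # illustrative per-token RTD logits
two_class = tf.concat([-1.0 * logits, logits], -1)
print(tf.argmax(two_class, -1).numpy())  # [1 0]: replaced vs. original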
official/nlp/tasks/electra_task_test.py  (new file, mode 100644, +59)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.electra_task."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task


class ELECTRAPretrainTaskTest(tf.test.TestCase):

  def test_task(self):
    config = electra_task.ELECTRAPretrainConfig(
        model=electra.ELECTRAPretrainerConfig(
            generator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            discriminator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            num_masked_tokens=20,
            sequence_length=128,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = electra_task.ELECTRAPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()
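With the Model Garden checkout on PYTHONPATH, the new test should run standalone (the exact invocation is an assumption; any standard test runner works):

python3 official/nlp/tasks/electra_task_test.py

It drives one train_step and one validation_step over the 'dummy' dataset wired up in build_inputs, exercising the joint MLM + RTD loss path end to end without real data.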