ModelZoo / ResNet50_tensorflow

Commit 3db5347d
Authored Apr 13, 2020 by Chen Chen; committed by A. Unique TensorFlower on Apr 13, 2020

Allow not using next sentence labels in pretraining.

PiperOrigin-RevId: 306324960
Parent: 1025682f
Showing 3 changed files with 72 additions and 43 deletions:

  official/nlp/bert/bert_models.py      +51  -33
  official/nlp/bert/input_pipeline.py    +7   -4
  official/nlp/bert/run_pretraining.py  +14   -6
official/nlp/bert/bert_models.py
@@ -54,29 +54,41 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(
         lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
-        next_sentence_accuracy, name='next_sentence_accuracy',
-        aggregation='mean')
-    self.add_metric(
-        next_sentence_loss, name='next_sentence_loss', aggregation='mean')
-
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
+          next_sentence_accuracy,
+          name='next_sentence_accuracy',
+          aggregation='mean')
+
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+
+    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
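The heart of this change is the loss layer's new control flow: the masked-LM loss is always computed, while the next-sentence loss and accuracy metric are added only when sentence labels are actually provided. Below is a minimal standalone sketch of that pattern, using tf.keras.losses.sparse_categorical_crossentropy as a stand-in for this repository's losses.weighted_sparse_categorical_crossentropy_loss helper; the function name, tensor shapes, and the epsilon term are illustrative assumptions, not this file's API.

import tensorflow as tf

def pretrain_loss(lm_logits, lm_label_ids, lm_label_weights,
                  sentence_logits=None, sentence_labels=None):
  """Illustrative sketch: masked-LM loss always, NSP loss only when labeled."""
  lm_label_weights = tf.cast(lm_label_weights, tf.float32)
  # Per-position masked-LM loss, weighted so padded prediction slots are ignored.
  lm_per_position = tf.keras.losses.sparse_categorical_crossentropy(
      lm_label_ids, lm_logits, from_logits=True)
  mask_label_loss = tf.reduce_sum(lm_per_position * lm_label_weights) / (
      tf.reduce_sum(lm_label_weights) + 1e-5)
  if sentence_labels is None:
    # No next-sentence labels: the total loss is the masked-LM loss alone.
    return mask_label_loss
  sentence_loss = tf.reduce_mean(
      tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_logits, from_logits=True))
  return mask_label_loss + sentence_loss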
@@ -155,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.

   Args:
@@ -164,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.

   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -185,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,), name='masked_lm_weights', dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None

   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -206,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
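With the new keyword argument, a masked-LM-only pretraining model can be built directly. A hedged usage sketch follows; the BertConfig values are placeholders, and it assumes configs.BertConfig and bert_models.pretrain_model are importable from official/nlp/bert as in this repository.

from official.nlp.bert import bert_models
from official.nlp.bert import configs

# Placeholder config; real runs usually load this from a bert_config.json.
bert_config = configs.BertConfig(
    vocab_size=30522, hidden_size=768, num_hidden_layers=12,
    num_attention_heads=12, intermediate_size=3072)

# use_next_sentence_label=False drops the 'next_sentence_labels' input, so the
# returned Keras model expects only the masked-LM features.
pretraining_model, encoder = bert_models.pretrain_model(
    bert_config, seq_length=128, max_predictions_per_seq=20,
    use_next_sentence_label=False)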
@@ -313,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(hub_module_url,
-                              trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
official/nlp/bert/input_pipeline.py
@@ -59,7 +59,8 @@ def create_pretrain_dataset(input_patterns,
                              max_predictions_per_seq,
                              batch_size,
                              is_training=True,
-                             input_pipeline_context=None):
+                             input_pipeline_context=None,
+                             use_next_sentence_label=True):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -74,9 +75,10 @@ def create_pretrain_dataset(input_patterns,
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
       'masked_lm_weights':
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
-      'next_sentence_labels':
-          tf.io.FixedLenFeature([1], tf.int64),
   }
+  if use_next_sentence_label:
+    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
+        [1], tf.int64)

   for input_pattern in input_patterns:
     if not tf.io.gfile.glob(input_pattern):
@@ -118,8 +120,9 @@ def create_pretrain_dataset(input_patterns,
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
-        'next_sentence_labels': record['next_sentence_labels'],
     }
+    if use_next_sentence_label:
+      x['next_sentence_labels'] = record['next_sentence_labels']

     y = record['masked_lm_weights']
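The same switch on the input side means 'next_sentence_labels' is neither parsed from the TFRecords nor fed to the model, so records written without that feature can be consumed. A sketch of building such a dataset with the function modified above; the file pattern and sizes are placeholders, and the second positional argument is assumed to be the sequence length (that part of the signature is elided from the hunks shown).

from official.nlp.bert import input_pipeline

# No 'next_sentence_labels' feature is parsed or batched in this mode.
train_dataset = input_pipeline.create_pretrain_dataset(
    ['gs://my-bucket/pretrain/*.tfrecord'],  # input_patterns
    128,                                     # seq_length
    20,                                      # max_predictions_per_seq
    32,                                      # batch_size
    is_training=True,
    use_next_sentence_label=False)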
official/nlp/bert/run_pretraining.py
@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
                      'Total number of training steps to run per epoch.')
 flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+                  'Whether to use next sentence label to compute final loss.')

 common_flags.define_common_bert_flags()
 common_flags.define_gin_flags()
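Because use_next_sentence_label is an absl boolean flag that defaults to True, existing pretraining jobs keep their current behavior; a job whose TFRecords lack NSP labels would launch with --use_next_sentence_label=false (or the --nouse_next_sentence_label spelling absl also accepts), and the remaining hunks thread that value from run_bert_pretrain through run_customized_training into the dataset and model builders.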
@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
 def get_pretrain_dataset_fn(input_file_pattern, seq_length,
-                            max_predictions_per_seq, global_batch_size):
+                            max_predictions_per_seq, global_batch_size,
+                            use_next_sentence_label=True):
   """Returns input dataset from input file string."""

   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        use_next_sentence_label=use_next_sentence_label)
     return train_dataset

   return _dataset_fn
@@ -95,17 +99,20 @@ def run_customized_training(strategy,
                             end_lr,
                             optimizer_type,
                             input_files,
-                            train_batch_size):
+                            train_batch_size,
+                            use_next_sentence_label=True):
   """Run BERT pretrain model training using low-level API."""

   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                            max_predictions_per_seq,
-                                           train_batch_size)
+                                           train_batch_size,
+                                           use_next_sentence_label)

   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
-        bert_config, max_seq_length, max_predictions_per_seq)
+        bert_config, max_seq_length, max_predictions_per_seq,
+        use_next_sentence_label=use_next_sentence_label)
     optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps, end_lr,
         optimizer_type)
@@ -157,7 +164,8 @@ def run_bert_pretrain(strategy):
       FLAGS.end_lr,
       FLAGS.optimizer_type,
       FLAGS.input_files,
-      FLAGS.train_batch_size)
+      FLAGS.train_batch_size,
+      FLAGS.use_next_sentence_label)


 def main(_):