ModelZoo / ResNet50_tensorflow / Commits / 3db5347d

Commit 3db5347d
Authored Apr 13, 2020 by Chen Chen; committed by A. Unique TensorFlower, Apr 13, 2020

    Allow to not use next sentence labels in pretraining.

    PiperOrigin-RevId: 306324960

Parent: 1025682f
Changes: 3 changed files, with 72 additions and 43 deletions.

    official/nlp/bert/bert_models.py     (+51, -33)
    official/nlp/bert/input_pipeline.py  (+7, -4)
    official/nlp/bert/run_pretraining.py (+14, -6)
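The commit threads a new use_next_sentence_label argument from the run_pretraining.py flags down through the input pipeline and into bert_models.pretrain_model, so BERT can be pretrained on the masked-LM objective alone. As a quick orientation before the per-file diffs, a minimal sketch of the new call, assuming the TF Model Garden "official" package (as of this commit) is importable; the config path and sizes are placeholders, not part of the commit:

    from official.nlp.bert import bert_models, configs

    bert_config = configs.BertConfig.from_json_file('bert_config.json')  # placeholder path

    # use_next_sentence_label=False is the new option added by this commit.
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config,
        seq_length=128,
        max_predictions_per_seq=20,
        use_next_sentence_label=False)
    # The returned Keras model then has no 'next_sentence_labels' input and its
    # loss output covers only the masked-LM term.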
official/nlp/bert/bert_models.py (view file @ 3db5347d)

@@ -54,6 +54,7 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
-    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
-        sentence_labels, sentence_output)
-    self.add_metric(
+    if sentence_labels is not None:
+      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
+          sentence_labels, sentence_output)
+      self.add_metric(
@@ -61,22 +62,33 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
           name='next_sentence_accuracy',
           aggregation='mean')
-      self.add_metric(
-          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+    if next_sentence_loss is not None:
+      self.add_metric(
+          next_sentence_loss, name='next_sentence_loss', aggregation='mean')
 
-  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
-           sentence_labels):
+  def call(self,
+           lm_output,
+           sentence_output,
+           lm_label_ids,
+           lm_label_weights,
+           sentence_labels=None):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
     lm_output = tf.cast(lm_output, tf.float32)
-    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
-    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
-        labels=sentence_labels, predictions=sentence_output)
-    loss = mask_label_loss + sentence_loss
-    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+    if sentence_labels is not None:
+      sentence_output = tf.cast(sentence_output, tf.float32)
+      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+          labels=sentence_labels, predictions=sentence_output)
+      loss = mask_label_loss + sentence_loss
+      batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
+    else:
+      sentence_loss = None
+      loss = mask_label_loss
+      batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
     final_loss = tf.fill(batch_shape, loss)
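The new else branch keeps the existing tf.fill trick working when sentence_labels is absent: the scalar loss has to be tiled into a per-example tensor so the layer can return it as a model output, and the batch size is now read from lm_label_ids instead of sentence_labels. A standalone toy illustration of that tiling (not code from the commit):

    import tensorflow as tf

    loss = tf.constant(2.5)                     # scalar total loss
    lm_label_ids = tf.zeros([8, 20], tf.int32)  # [batch, max_predictions_per_seq]
    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])  # -> [8]
    final_loss = tf.fill(batch_shape, loss)     # shape [8], every entry 2.5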
@@ -155,7 +167,8 @@ def get_transformer_encoder(bert_config,
 def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
-                   initializer=None):
+                   initializer=None,
+                   use_next_sentence_label=True):
   """Returns model to be used for pre-training.
 
   Args:
@@ -164,6 +177,7 @@ def pretrain_model(bert_config,
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
+    use_next_sentence_label: Whether to use the next sentence label.
 
   Returns:
     Pretraining model as well as core BERT submodel from which to save
@@ -185,8 +199,12 @@ def pretrain_model(bert_config,
       shape=(max_predictions_per_seq,),
       name='masked_lm_weights',
       dtype=tf.int32)
-  next_sentence_labels = tf.keras.layers.Input(
-      shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  if use_next_sentence_label:
+    next_sentence_labels = tf.keras.layers.Input(
+        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+  else:
+    next_sentence_labels = None
 
   transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
@@ -206,17 +224,18 @@ def pretrain_model(bert_config,
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                     masked_lm_weights, next_sentence_labels)
-  keras_model = tf.keras.Model(
-      inputs={
-          'input_word_ids': input_word_ids,
-          'input_mask': input_mask,
-          'input_type_ids': input_type_ids,
-          'masked_lm_positions': masked_lm_positions,
-          'masked_lm_ids': masked_lm_ids,
-          'masked_lm_weights': masked_lm_weights,
-          'next_sentence_labels': next_sentence_labels,
-      },
-      outputs=output_loss)
+  inputs = {
+      'input_word_ids': input_word_ids,
+      'input_mask': input_mask,
+      'input_type_ids': input_type_ids,
+      'masked_lm_positions': masked_lm_positions,
+      'masked_lm_ids': masked_lm_ids,
+      'masked_lm_weights': masked_lm_weights,
+  }
+  if use_next_sentence_label:
+    inputs['next_sentence_labels'] = next_sentence_labels
+  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
   return keras_model, transformer_encoder
@@ -313,8 +332,7 @@ def classifier_model(bert_config,
       shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
   input_type_ids = tf.keras.layers.Input(
       shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
-  bert_model = hub.KerasLayer(
-      hub_module_url, trainable=hub_module_trainable)
+  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
   pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
   output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
       pooled_output)
official/nlp/bert/input_pipeline.py (view file @ 3db5347d)

@@ -59,7 +59,8 @@ def create_pretrain_dataset(input_patterns,
                             max_predictions_per_seq,
                             batch_size,
                             is_training=True,
-                            input_pipeline_context=None):
+                            input_pipeline_context=None,
+                            use_next_sentence_label=True):
   """Creates input dataset from (tf)records files for pretraining."""
   name_to_features = {
       'input_ids':
@@ -74,9 +75,10 @@ def create_pretrain_dataset(input_patterns,
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
       'masked_lm_weights':
           tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
-      'next_sentence_labels':
-          tf.io.FixedLenFeature([1], tf.int64),
   }
+  if use_next_sentence_label:
+    name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
+        [1], tf.int64)
 
   for input_pattern in input_patterns:
     if not tf.io.gfile.glob(input_pattern):
@@ -118,8 +120,9 @@ def create_pretrain_dataset(input_patterns,
         'masked_lm_positions': record['masked_lm_positions'],
         'masked_lm_ids': record['masked_lm_ids'],
         'masked_lm_weights': record['masked_lm_weights'],
-        'next_sentence_labels': record['next_sentence_labels'],
     }
+    if use_next_sentence_label:
+      x['next_sentence_labels'] = record['next_sentence_labels']
 
     y = record['masked_lm_weights']
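With both the feature spec and the example-selection dict made conditional, a dataset built without NSP simply never parses or emits 'next_sentence_labels'. A hedged usage sketch of the updated function; the file pattern and sizes are placeholders, not part of the commit:

    from official.nlp.bert import input_pipeline

    dataset = input_pipeline.create_pretrain_dataset(
        input_patterns=['/data/pretrain-*.tfrecord'],  # placeholder pattern
        seq_length=128,
        max_predictions_per_seq=20,
        batch_size=32,
        is_training=True,
        use_next_sentence_label=False)
    # Each element is (x, y): x contains 'input_word_ids', 'input_mask',
    # 'input_type_ids', 'masked_lm_positions', 'masked_lm_ids' and
    # 'masked_lm_weights'; 'next_sentence_labels' appears only when
    # use_next_sentence_label=True.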
official/nlp/bert/run_pretraining.py (view file @ 3db5347d)

@@ -47,6 +47,8 @@ flags.DEFINE_integer('num_steps_per_epoch', 1000,
                      'Total number of training steps to run per epoch.')
 flags.DEFINE_float('warmup_steps', 10000,
                    'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+                  'Whether to use next sentence label to compute final loss.')
 
 common_flags.define_common_bert_flags()
 common_flags.define_gin_flags()
@@ -55,7 +57,8 @@ FLAGS = flags.FLAGS
 def get_pretrain_dataset_fn(input_file_pattern, seq_length,
-                            max_predictions_per_seq, global_batch_size):
+                            max_predictions_per_seq, global_batch_size,
+                            use_next_sentence_label=True):
   """Returns input dataset from input file string."""
 
   def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
@@ -67,7 +70,8 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
-        input_pipeline_context=ctx)
+        input_pipeline_context=ctx,
+        use_next_sentence_label=use_next_sentence_label)
     return train_dataset
 
   return _dataset_fn
@@ -95,17 +99,20 @@ def run_customized_training(strategy,
                             end_lr,
                             optimizer_type,
                             input_files,
-                            train_batch_size):
+                            train_batch_size,
+                            use_next_sentence_label=True):
   """Run BERT pretrain model training using low-level API."""
 
   train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
                                            max_predictions_per_seq,
-                                           train_batch_size)
+                                           train_batch_size,
+                                           use_next_sentence_label)
 
   def _get_pretrain_model():
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
-        bert_config, max_seq_length, max_predictions_per_seq)
+        bert_config, max_seq_length, max_predictions_per_seq,
+        use_next_sentence_label=use_next_sentence_label)
     optimizer = optimization.create_optimizer(
         initial_lr, steps_per_epoch * epochs, warmup_steps,
         end_lr, optimizer_type)
@@ -157,7 +164,8 @@ def run_bert_pretrain(strategy):
       FLAGS.end_lr,
       FLAGS.optimizer_type,
      FLAGS.input_files,
-      FLAGS.train_batch_size)
+      FLAGS.train_batch_size,
+      FLAGS.use_next_sentence_label)
 
 
 def main(_):
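End to end, the new --use_next_sentence_label flag is plumbed from the command line through run_customized_training into both the dataset function and the model builder. A sketch of the dataset side as run_customized_training now calls it; argument values below are placeholders, not part of the commit:

    from official.nlp.bert import run_pretraining

    dataset_fn = run_pretraining.get_pretrain_dataset_fn(
        input_file_pattern='/data/pretrain-*.tfrecord',  # placeholder
        seq_length=128,
        max_predictions_per_seq=20,
        global_batch_size=256,
        use_next_sentence_label=False)
    # dataset_fn(ctx) builds the per-replica tf.data.Dataset; with the flag off,
    # the parsed examples carry no 'next_sentence_labels' feature, matching the
    # model built with use_next_sentence_label=False.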