"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "d1752c0f49a4e494db31e9314c1619025d9e68ce"
Commit 3dccfae1 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 311602262
parent 8c408bbe
......@@ -63,7 +63,8 @@ def create_pretrain_dataset(input_patterns,
is_training=True,
input_pipeline_context=None,
use_next_sentence_label=True,
use_position_id=False):
use_position_id=False,
output_fake_labels=True):
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features = {
'input_ids':
......@@ -135,9 +136,11 @@ def create_pretrain_dataset(input_patterns,
if use_position_id:
x['position_ids'] = record['position_ids']
y = record['masked_lm_weights']
return (x, y)
# TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
if output_fake_labels:
return (x, record['masked_lm_weights'])
else:
return x
dataset = dataset.map(
_select_data_from_record,
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -21,6 +20,7 @@ from __future__ import print_function
import re
from absl import logging
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
......@@ -67,6 +67,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
}
@gin.configurable
def create_optimizer(init_lr,
num_train_steps,
num_warmup_steps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment