Commit 3dccfae1 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 311602262
parent 8c408bbe
...@@ -63,7 +63,8 @@ def create_pretrain_dataset(input_patterns, ...@@ -63,7 +63,8 @@ def create_pretrain_dataset(input_patterns,
is_training=True, is_training=True,
input_pipeline_context=None, input_pipeline_context=None,
use_next_sentence_label=True, use_next_sentence_label=True,
use_position_id=False): use_position_id=False,
output_fake_labels=True):
"""Creates input dataset from (tf)records files for pretraining.""" """Creates input dataset from (tf)records files for pretraining."""
name_to_features = { name_to_features = {
'input_ids': 'input_ids':
...@@ -135,9 +136,11 @@ def create_pretrain_dataset(input_patterns, ...@@ -135,9 +136,11 @@ def create_pretrain_dataset(input_patterns,
if use_position_id: if use_position_id:
x['position_ids'] = record['position_ids'] x['position_ids'] = record['position_ids']
y = record['masked_lm_weights'] # TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
if output_fake_labels:
return (x, y) return (x, record['masked_lm_weights'])
else:
return x
dataset = dataset.map( dataset = dataset.map(
_select_data_from_record, _select_data_from_record,
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Functions and classes related to optimization (weight updates).""" """Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -21,6 +20,7 @@ from __future__ import print_function ...@@ -21,6 +20,7 @@ from __future__ import print_function
import re import re
from absl import logging from absl import logging
import gin
import tensorflow as tf import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers import tensorflow_addons.optimizers as tfa_optimizers
...@@ -67,6 +67,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): ...@@ -67,6 +67,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
} }
@gin.configurable
def create_optimizer(init_lr, def create_optimizer(init_lr,
num_train_steps, num_train_steps,
num_warmup_steps, num_warmup_steps,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment