Commit 6c63efed authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319697990
parent 36e786dc
......@@ -57,8 +57,9 @@ class BertModelsTest(tf.test.TestCase):
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(),
["encoder", "next_sentence.pooler_dense"])
self.assertSameElements(
encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__":
......
......@@ -217,7 +217,7 @@ class BertPretrainerV2(tf.keras.Model):
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.encoder_network)
items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
for head in self.classification_heads:
for key, item in head.checkpoint_items.items():
items['.'.join([head.name, key])] = item
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Masked language task."""
from absl import logging
import dataclasses
import tensorflow as tf
......@@ -26,6 +27,7 @@ from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
"""The model config."""
init_checkpoint: str = ''
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
......@@ -171,3 +173,17 @@ class MaskedLMTask(base_task.Task):
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def initialize(self, model: tf.keras.Model):
  """Loads pretrained weights into `model` from `init_checkpoint`, if set.

  Restores every module exposed by `model.checkpoint_items` (e.g. encoder,
  masked_lm, classification-head poolers). When the configured path is a
  directory, the newest checkpoint inside it is used. An empty or missing
  path makes this a no-op.

  Args:
    model: The Keras model whose `checkpoint_items` modules are restored.
  """
  ckpt_path = self.task_config.init_checkpoint
  # A directory means "pick the most recent checkpoint written there".
  if tf.io.gfile.isdir(ckpt_path):
    ckpt_path = tf.train.latest_checkpoint(ckpt_path)
  if not ckpt_path:
    return
  # Restoring all modules defined by the model, e.g. encoder, masked_lm and
  # cls pooler. The best initialization may vary case by case.
  checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
  restore_status = checkpoint.read(ckpt_path)
  restore_status.expect_partial().assert_existing_objects_matched()
  logging.info('Finished loading pretrained checkpoint from %s',
               ckpt_path)
......@@ -27,6 +27,7 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self):
config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
model=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=20,
......@@ -49,6 +50,12 @@ class MLMTaskTest(tf.test.TestCase):
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
ckpt = tf.train.Checkpoint(
model=model, **model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
if __name__ == "__main__":
tf.test.main()
......@@ -282,5 +282,5 @@ class QuestionAnsweringTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......@@ -189,5 +189,5 @@ class SentencePredictionTask(base_task.Task):
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......@@ -212,5 +212,5 @@ class TaggingTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment