Commit 6c63efed authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319697990
parent 36e786dc
...@@ -57,8 +57,9 @@ class BertModelsTest(tf.test.TestCase): ...@@ -57,8 +57,9 @@ class BertModelsTest(tf.test.TestCase):
inner_dim=10, num_classes=2, name="next_sentence") inner_dim=10, num_classes=2, name="next_sentence")
]) ])
encoder = bert.instantiate_bertpretrainer_from_cfg(config) encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(), self.assertSameElements(
["encoder", "next_sentence.pooler_dense"]) encoder.checkpoint_items.keys(),
["encoder", "masked_lm", "next_sentence.pooler_dense"])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -217,7 +217,7 @@ class BertPretrainerV2(tf.keras.Model): ...@@ -217,7 +217,7 @@ class BertPretrainerV2(tf.keras.Model):
@property @property
def checkpoint_items(self): def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed.""" """Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.encoder_network) items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
for head in self.classification_heads: for head in self.classification_heads:
for key, item in head.checkpoint_items.items(): for key, item in head.checkpoint_items.items():
items['.'.join([head.name, key])] = item items['.'.join([head.name, key])] = item
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Masked language task.""" """Masked language task."""
from absl import logging
import dataclasses import dataclasses
import tensorflow as tf import tensorflow as tf
...@@ -26,6 +27,7 @@ from official.nlp.data import data_loader_factory ...@@ -26,6 +27,7 @@ from official.nlp.data import data_loader_factory
@dataclasses.dataclass @dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig): class MaskedLMConfig(cfg.TaskConfig):
"""The model config.""" """The model config."""
init_checkpoint: str = ''
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[ model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence') inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
...@@ -171,3 +173,17 @@ class MaskedLMTask(base_task.Task): ...@@ -171,3 +173,17 @@ class MaskedLMTask(base_task.Task):
aux_losses=model.losses) aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs) self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss} return {self.loss: loss}
def initialize(self, model: tf.keras.Model):
  """Restores pretrained weights into `model` from `init_checkpoint`.

  Resolves the configured checkpoint path (accepting either a directory,
  from which the latest checkpoint is picked, or a direct checkpoint file)
  and restores all modules the model exposes via `checkpoint_items`.
  Silently returns when no checkpoint is configured or found.

  Args:
    model: The Keras model whose submodules are to be initialized.
  """
  ckpt_path = self.task_config.init_checkpoint
  if tf.io.gfile.isdir(ckpt_path):
    ckpt_path = tf.train.latest_checkpoint(ckpt_path)
  if not ckpt_path:
    return
  # Restoring all modules defined by the model, e.g. encoder, masked_lm and
  # cls pooler. The best initialization may vary case by case.
  checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
  restore_status = checkpoint.read(ckpt_path)
  restore_status.expect_partial().assert_existing_objects_matched()
  logging.info('Finished loading pretrained checkpoint from %s',
               ckpt_path)
...@@ -27,6 +27,7 @@ class MLMTaskTest(tf.test.TestCase): ...@@ -27,6 +27,7 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self): def test_task(self):
config = masked_lm.MaskedLMConfig( config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
model=bert.BertPretrainerConfig( model=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1), encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=20, num_masked_tokens=20,
...@@ -49,6 +50,12 @@ class MLMTaskTest(tf.test.TestCase): ...@@ -49,6 +50,12 @@ class MLMTaskTest(tf.test.TestCase):
task.train_step(next(iterator), model, optimizer, metrics=metrics) task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics) task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
ckpt = tf.train.Checkpoint(
model=model, **model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -282,5 +282,5 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -282,5 +282,5 @@ class QuestionAnsweringTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -189,5 +189,5 @@ class SentencePredictionTask(base_task.Task): ...@@ -189,5 +189,5 @@ class SentencePredictionTask(base_task.Task):
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping) ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.read(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -212,5 +212,5 @@ class TaggingTask(base_task.Task): ...@@ -212,5 +212,5 @@ class TaggingTask(base_task.Task):
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment