Commit 4b46ab20 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319114361
parent 3300fa04
......@@ -280,7 +280,7 @@ class QuestionAnsweringTask(base_task.Task):
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......@@ -35,6 +35,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
# At most one of `init_checkpoint` and `hub_module_url` can
# be specified.
init_checkpoint: str = ''
init_cls_pooler: bool = False
hub_module_url: str = ''
metric_type: str = 'accuracy'
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
......@@ -58,7 +59,7 @@ class SentencePredictionTask(base_task.Task):
super(SentencePredictionTask, self).__init__(params)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`pretrain_checkpoint_dir` can be specified.')
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
......@@ -178,13 +179,16 @@ class SentencePredictionTask(base_task.Task):
return
pretrain2finetune_mapping = {
'encoder':
model.checkpoint_items['encoder'],
'next_sentence.pooler_dense':
model.checkpoint_items['sentence_prediction.pooler_dense'],
'encoder': model.checkpoint_items['encoder'],
}
# TODO(b/160251903): Investigate why no pooler dense improves finetuning
# accuracies.
if self.task_config.init_cls_pooler:
pretrain2finetune_mapping[
'next_sentence.pooler_dense'] = model.checkpoint_items[
'sentence_prediction.pooler_dense']
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.restore(ckpt_dir_or_file)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment