"examples/vscode:/vscode.git/clone" did not exist on "72a98a86c6bc17573cbe61a26061689b3830e5af"
Commit 4b46ab20 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319114361
parent 3300fa04
...@@ -280,7 +280,7 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -280,7 +280,7 @@ class QuestionAnsweringTask(base_task.Task):
return return
ckpt = tf.train.Checkpoint(**model.checkpoint_items) ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s', logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
...@@ -35,6 +35,7 @@ class SentencePredictionConfig(cfg.TaskConfig): ...@@ -35,6 +35,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
# At most one of `init_checkpoint` and `hub_module_url` can # At most one of `init_checkpoint` and `hub_module_url` can
# be specified. # be specified.
init_checkpoint: str = '' init_checkpoint: str = ''
init_cls_pooler: bool = False
hub_module_url: str = '' hub_module_url: str = ''
metric_type: str = 'accuracy' metric_type: str = 'accuracy'
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig( model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
...@@ -58,7 +59,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -58,7 +59,7 @@ class SentencePredictionTask(base_task.Task):
super(SentencePredictionTask, self).__init__(params) super(SentencePredictionTask, self).__init__(params)
if params.hub_module_url and params.init_checkpoint: if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`pretrain_checkpoint_dir` can be specified.') '`init_checkpoint` can be specified.')
if params.hub_module_url: if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url) self._hub_module = hub.load(params.hub_module_url)
else: else:
...@@ -178,13 +179,16 @@ class SentencePredictionTask(base_task.Task): ...@@ -178,13 +179,16 @@ class SentencePredictionTask(base_task.Task):
return return
pretrain2finetune_mapping = { pretrain2finetune_mapping = {
'encoder': 'encoder': model.checkpoint_items['encoder'],
model.checkpoint_items['encoder'],
'next_sentence.pooler_dense':
model.checkpoint_items['sentence_prediction.pooler_dense'],
} }
# TODO(b/160251903): Investigate why no pooler dense improves finetuning
# accuracies.
if self.task_config.init_cls_pooler:
pretrain2finetune_mapping[
'next_sentence.pooler_dense'] = model.checkpoint_items[
'sentence_prediction.pooler_dense']
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping) ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s', logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment