Commit 143fd0b6 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Minor bug fixes

PiperOrigin-RevId: 422637653
parent 871c4e0a
...@@ -101,9 +101,11 @@ class Task(tf.Module, metaclass=abc.ABCMeta): ...@@ -101,9 +101,11 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
ckpt_dir_or_file = self.task_config.init_checkpoint ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s", logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file) ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file): if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file: if not ckpt_dir_or_file:
logging.info("No checkpoint file found from %s. Will not load.",
ckpt_dir_or_file)
return return
if hasattr(model, "checkpoint_items"): if hasattr(model, "checkpoint_items"):
......
...@@ -187,9 +187,13 @@ class DualEncoderTask(base_task.Task): ...@@ -187,9 +187,13 @@ class DualEncoderTask(base_task.Task):
def initialize(self, model): def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0.""" """Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file): logging.info('Trying to load pretrained checkpoint from %s',
ckpt_dir_or_file)
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file: if not ckpt_dir_or_file:
logging.info('No checkpoint file found from %s. Will not load.',
ckpt_dir_or_file)
return return
pretrain2finetune_mapping = { pretrain2finetune_mapping = {
......
...@@ -223,10 +223,14 @@ class SentencePredictionTask(base_task.Task): ...@@ -223,10 +223,14 @@ class SentencePredictionTask(base_task.Task):
def initialize(self, model): def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0.""" """Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info('Trying to load pretrained checkpoint from %s',
ckpt_dir_or_file)
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file: if not ckpt_dir_or_file:
logging.info('No checkpoint file found from %s. Will not load.',
ckpt_dir_or_file)
return return
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
pretrain2finetune_mapping = { pretrain2finetune_mapping = {
'encoder': model.checkpoint_items['encoder'], 'encoder': model.checkpoint_items['encoder'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment