Commit a565d720 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323732686
parent 250701c6
......@@ -18,6 +18,7 @@ import abc
import functools
from typing import Any, Callable, Optional
from absl import logging
import six
import tensorflow as tf
......@@ -67,7 +68,19 @@ class Task(tf.Module):
Args:
model: The keras.Model built or used by this task.
"""
pass
ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
......
......@@ -179,6 +179,7 @@ class TrainerConfig(base_config.Config):
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
This is only used in continuous_train_and_eval and continuous_eval modes.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
......@@ -205,6 +206,7 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
init_checkpoint: str = ""
model: base_config.Config = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
......
......@@ -14,7 +14,6 @@
# limitations under the License.
# ==============================================================================
"""Masked language task."""
from absl import logging
import dataclasses
import tensorflow as tf
......@@ -27,7 +26,6 @@ from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
"""The model config."""
init_checkpoint: str = ''
model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
......@@ -174,17 +172,3 @@ class MaskedLMTask(base_task.Task):
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def initialize(self, model: tf.keras.Model):
  """Warm-starts `model` from `task_config.init_checkpoint`, if one is set.

  The configured path may be either a checkpoint file or a directory; a
  directory is resolved to its most recent checkpoint. When no checkpoint
  can be resolved, the method is a no-op and the model keeps its random
  initialization.

  Args:
    model: The keras.Model whose `checkpoint_items` will be restored.
  """
  checkpoint_path = self.task_config.init_checkpoint
  if tf.io.gfile.isdir(checkpoint_path):
    # A directory was given; pick the newest checkpoint inside it.
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  if not checkpoint_path:
    return
  # Restoring all modules defined by the model, e.g. encoder, masked_lm and
  # cls pooler. The best initialization may vary case by case.
  checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
  restore_status = checkpoint.read(checkpoint_path)
  restore_status.expect_partial().assert_existing_objects_matched()
  logging.info('Finished loading pretrained checkpoint from %s',
               checkpoint_path)
......@@ -290,17 +290,3 @@ class QuestionAnsweringTask(base_task.Task):
eval_metrics = {'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']}
return eval_metrics
def initialize(self, model):
  """Load a pretrained checkpoint (if exists) and then train from iter 0."""
  path = self.task_config.init_checkpoint
  if tf.io.gfile.isdir(path):
    # Resolve a checkpoint directory to the latest checkpoint it contains.
    path = tf.train.latest_checkpoint(path)
  if not path:
    # Nothing configured (or an empty directory): start from scratch.
    return
  status = tf.train.Checkpoint(**model.checkpoint_items).read(path)
  status.expect_partial().assert_existing_objects_matched()
  logging.info('Finished loading pretrained checkpoint from %s', path)
......@@ -14,7 +14,6 @@
# limitations under the License.
# ==============================================================================
"""Tagging (e.g., NER/POS) task."""
import logging
from typing import List, Optional, Tuple
import dataclasses
......@@ -215,20 +214,6 @@ class TaggingTask(base_task.Task):
seqeval_metrics.accuracy_score(label_class, predict_class),
}
def initialize(self, model):
  """Load a pretrained checkpoint (if exists) and then train from iter 0."""
  init_path = self.task_config.init_checkpoint
  # A directory argument means "use the most recent checkpoint in it".
  init_path = (tf.train.latest_checkpoint(init_path)
               if tf.io.gfile.isdir(init_path) else init_path)
  if not init_path:
    # No usable checkpoint; keep the freshly built weights.
    return
  restorer = tf.train.Checkpoint(**model.checkpoint_items)
  restore_status = restorer.restore(init_path)
  restore_status.expect_partial().assert_existing_objects_matched()
  logging.info('Finished loading pretrained checkpoint from %s',
               init_path)
def predict(task: TaggingTask, params: cfg.DataConfig,
model: tf.keras.Model) -> Tuple[List[List[int]], List[int]]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment