"llm/vscode:/vscode.git/clone" did not exist on "9d91e5e5875e2b2f8605ef15a7da9a616cb05171"
Commit a565d720 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 323732686
parent 250701c6
@@ -18,6 +18,7 @@
 import abc
 import functools
 from typing import Any, Callable, Optional
+from absl import logging
 import six
 import tensorflow as tf
@@ -67,7 +68,19 @@ class Task(tf.Module):
     Args:
       model: The keras.Model built or used by this task.
     """
-    pass
+    ckpt_dir_or_file = self.task_config.init_checkpoint
+    logging.info("Trying to load pretrained checkpoint from %s",
+                 ckpt_dir_or_file)
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
+      return
+    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
+    status = ckpt.restore(ckpt_dir_or_file)
+    status.expect_partial().assert_existing_objects_matched()
+    logging.info("Finished loading pretrained checkpoint from %s",
+                 ckpt_dir_or_file)

   @abc.abstractmethod
   def build_model(self) -> tf.keras.Model:
...
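For context on what the new base implementation does: tf.train.Checkpoint is constructed from the model's checkpoint_items mapping, so checkpoint keys follow the names in that mapping, and restore plus expect_partial().assert_existing_objects_matched() tolerates extra objects in the checkpoint (e.g. optimizer slots) while insisting that everything the model asked for was found. A minimal runnable sketch, assuming a hypothetical DemoModel; nothing below is part of this change:

    import tensorflow as tf

    class DemoModel(tf.keras.Model):
      """Hypothetical model exposing checkpoint_items like the tasks' models."""

      def __init__(self):
        super().__init__()
        self.encoder = tf.keras.layers.Dense(4, name="encoder")

      def call(self, inputs):
        return self.encoder(inputs)

      @property
      def checkpoint_items(self):
        # Keys must match the names used when the checkpoint was written.
        return {"encoder": self.encoder}

    model = DemoModel()
    model(tf.zeros([1, 4]))  # Build variables before saving/restoring.

    # Write a checkpoint, then restore it the way Task.initialize does.
    path = tf.train.Checkpoint(encoder=model.encoder).save("/tmp/demo_ckpt/ckpt")
    status = tf.train.Checkpoint(**model.checkpoint_items).restore(path)
    # expect_partial(): ignore checkpoint contents we did not ask for;
    # assert_existing_objects_matched(): everything we did ask for was restored.
    status.expect_partial().assert_existing_objects_matched()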
@@ -179,6 +179,7 @@ class TrainerConfig(base_config.Config):
     max_to_keep: max checkpoints to keep.
     continuous_eval_timeout: maximum number of seconds to wait between
       checkpoints, if set to None, continuous eval will wait indefinitely.
+      This is only used in continuous_train_and_eval and continuous_eval modes.
     train_steps: number of train steps.
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.
@@ -205,6 +206,7 @@ class TrainerConfig(base_config.Config):

 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
+  init_checkpoint: str = ""
   model: base_config.Config = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
...
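With init_checkpoint promoted to the base TaskConfig, every task config inherits the field instead of redeclaring it. A short usage sketch; the checkpoint path below is a placeholder:

    from official.nlp.tasks.masked_lm import MaskedLMConfig

    # Any TaskConfig subclass now accepts init_checkpoint directly.
    config = MaskedLMConfig(init_checkpoint="/path/to/pretrained_ckpt")
    # Task.initialize resolves the value (a file, or a directory via
    # tf.train.latest_checkpoint) and restores the model before step 0.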
@@ -14,7 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 """Masked language task."""
-from absl import logging
 import dataclasses
 import tensorflow as tf

@@ -27,7 +26,6 @@ from official.nlp.data import data_loader_factory
 @dataclasses.dataclass
 class MaskedLMConfig(cfg.TaskConfig):
   """The model config."""
-  init_checkpoint: str = ''
   model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
       bert.ClsHeadConfig(
           inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
@@ -174,17 +172,3 @@ class MaskedLMTask(base_task.Task):
         aux_losses=model.losses)
     self.process_metrics(metrics, inputs, outputs)
     return {self.loss: loss}
-
-  def initialize(self, model: tf.keras.Model):
-    ckpt_dir_or_file = self.task_config.init_checkpoint
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-    if not ckpt_dir_or_file:
-      return
-    # Restoring all modules defined by the model, e.g. encoder, masked_lm and
-    # cls pooler. The best initialization may vary case by case.
-    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-    status = ckpt.read(ckpt_dir_or_file)
-    status.expect_partial().assert_existing_objects_matched()
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
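Two details of the consolidation are worth noting: it switches the masked-LM and question-answering paths from ckpt.read to ckpt.restore (matching what the tagging task already did), and, as the removed comment says, the best initialization can vary case by case, so a task needing different warm-start behavior can still override the inherited method. A hedged sketch; EncoderOnlyMaskedLMTask and the encoder-only restore are illustrative, not part of this diff:

    class EncoderOnlyMaskedLMTask(MaskedLMTask):

      def initialize(self, model: tf.keras.Model):
        ckpt_dir_or_file = self.task_config.init_checkpoint
        if tf.io.gfile.isdir(ckpt_dir_or_file):
          ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
        if not ckpt_dir_or_file:
          return
        # Warm-start only the encoder ("encoder" is assumed to be one of the
        # model's checkpoint_items keys); heads keep their random init.
        ckpt = tf.train.Checkpoint(encoder=model.checkpoint_items["encoder"])
        ckpt.restore(ckpt_dir_or_file).expect_partial()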
@@ -290,17 +290,3 @@ class QuestionAnsweringTask(base_task.Task):
     eval_metrics = {'exact_match': eval_metrics['exact_match'],
                     'final_f1': eval_metrics['final_f1']}
     return eval_metrics
-
-  def initialize(self, model):
-    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
-    ckpt_dir_or_file = self.task_config.init_checkpoint
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-    if not ckpt_dir_or_file:
-      return
-    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-    status = ckpt.read(ckpt_dir_or_file)
-    status.expect_partial().assert_existing_objects_matched()
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
@@ -14,7 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 """Tagging (e.g., NER/POS) task."""
-import logging
 from typing import List, Optional, Tuple
 import dataclasses

@@ -215,20 +214,6 @@ class TaggingTask(base_task.Task):
         seqeval_metrics.accuracy_score(label_class, predict_class),
     }
-
-  def initialize(self, model):
-    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
-    ckpt_dir_or_file = self.task_config.init_checkpoint
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-    if not ckpt_dir_or_file:
-      return
-    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-    status = ckpt.restore(ckpt_dir_or_file)
-    status.expect_partial().assert_existing_objects_matched()
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)

 def predict(task: TaggingTask, params: cfg.DataConfig,
             model: tf.keras.Model) -> Tuple[List[List[int]], List[int]]:
...
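The tagging task had also been importing the stdlib logging module rather than absl's; deleting its initialize removes the last use. End to end, a driver now only needs the inherited method. A sketch of the call site; the orchestration below is illustrative, not taken from this diff:

    task = TaggingTask(config)  # any Task subclass inherits initialize
    model = task.build_model()
    task.initialize(model)      # warm-starts from config.init_checkpoint, if set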