Commit 58d19c67 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319162455
parent 574455f5
......@@ -37,13 +37,25 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params: cfg.TaskConfig):
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
self._task_config = params
self._logging_dir = logging_dir
@property
def task_config(self) -> cfg.TaskConfig:
return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
......
......@@ -54,8 +54,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering."""
def __init__(self, params=cfg.TaskConfig):
super(QuestionAnsweringTask, self).__init__(params)
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(QuestionAnsweringTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
......@@ -72,6 +72,10 @@ class QuestionAnsweringTask(base_task.Task):
raise ValueError('Unsupported tokenization method: {}'.format(
params.validation_data.tokenization))
if params.validation_data.input_path:
self._tf_record_input_path, self._eval_examples, self._eval_features = (
self._preprocess_eval_data(params.validation_data))
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
......@@ -107,7 +111,11 @@ class QuestionAnsweringTask(base_task.Task):
is_training=False,
version_2_with_negative=params.version_2_with_negative)
temp_file_path = params.input_preprocessed_data_path or '/tmp'
temp_file_path = params.input_preprocessed_data_path or self.logging_dir
if not temp_file_path:
raise ValueError('You must specify a temporary directory, either in '
'params.input_preprocessed_data_path or logging_dir to '
'store intermediate evaluation TFRecord data.')
eval_writer = self.squad_lib.FeatureWriter(
filename=os.path.join(temp_file_path, 'eval.tf_record'),
is_training=False)
......@@ -168,8 +176,7 @@ class QuestionAnsweringTask(base_task.Task):
if params.is_training:
input_path = params.input_path
else:
input_path, self._eval_examples, self._eval_features = (
self._preprocess_eval_data(params))
input_path = self._tf_record_input_path
batch_size = input_context.get_per_replica_batch_size(
params.global_batch_size) if input_context else params.global_batch_size
......
......@@ -55,8 +55,8 @@ class SentencePredictionConfig(cfg.TaskConfig):
class SentencePredictionTask(base_task.Task):
"""Task object for sentence_prediction."""
def __init__(self, params=cfg.TaskConfig):
super(SentencePredictionTask, self).__init__(params)
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(SentencePredictionTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
......
......@@ -75,8 +75,8 @@ def _masked_labels_and_weights(y_true):
class TaggingTask(base_task.Task):
"""Task object for tagging (e.g., NER or POS)."""
def __init__(self, params=cfg.TaskConfig):
super(TaggingTask, self).__init__(params)
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(TaggingTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment