Commit 58d19c67 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 319162455
parent 574455f5
@@ -37,13 +37,25 @@ class Task(tf.Module):
 
   # Special keys in train/validate step returned logs.
   loss = "loss"
 
-  def __init__(self, params: cfg.TaskConfig):
+  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
+    """Task initialization.
+
+    Args:
+      params: cfg.TaskConfig instance.
+      logging_dir: a string pointing to where the model, summaries etc. will be
+        saved. You can also write additional stuff in this directory.
+    """
     self._task_config = params
+    self._logging_dir = logging_dir
 
   @property
   def task_config(self) -> cfg.TaskConfig:
     return self._task_config
 
+  @property
+  def logging_dir(self) -> str:
+    return self._logging_dir
+
   def initialize(self, model: tf.keras.Model):
     """A callback function used as CheckpointManager's init_fn.
...
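
This hunk gives the base Task an optional logging_dir and exposes it through the logging_dir property, so subclasses can write intermediate artifacts next to the model outputs instead of hard-coded temporary paths. A minimal sketch of how a subclass might use it, assuming the patched Task class above is in scope; MyTask and dump_debug_note are hypothetical and not part of this commit:

import os
import tensorflow as tf

class MyTask(Task):  # hypothetical subclass, for illustration only
  """Shows how a task can forward and consume the new logging_dir."""

  def __init__(self, params, logging_dir=None):
    # Forward logging_dir to the base class, mirroring the tasks updated below.
    super(MyTask, self).__init__(params, logging_dir)

  def dump_debug_note(self, text):
    # logging_dir defaults to None, so guard before writing to it.
    if self.logging_dir:
      with tf.io.gfile.GFile(
          os.path.join(self.logging_dir, 'debug.txt'), 'w') as f:
        f.write(text)
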
@@ -54,8 +54,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
 class QuestionAnsweringTask(base_task.Task):
   """Task object for question answering."""
 
-  def __init__(self, params=cfg.TaskConfig):
-    super(QuestionAnsweringTask, self).__init__(params)
+  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
+    super(QuestionAnsweringTask, self).__init__(params, logging_dir)
     if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`init_checkpoint` can be specified.')
@@ -72,6 +72,10 @@ class QuestionAnsweringTask(base_task.Task):
       raise ValueError('Unsupported tokenization method: {}'.format(
           params.validation_data.tokenization))
 
+    if params.validation_data.input_path:
+      self._tf_record_input_path, self._eval_examples, self._eval_features = (
+          self._preprocess_eval_data(params.validation_data))
+
   def build_model(self):
     if self._hub_module:
       encoder_network = utils.get_encoder_from_hub(self._hub_module)
@@ -107,7 +111,11 @@ class QuestionAnsweringTask(base_task.Task):
         is_training=False,
         version_2_with_negative=params.version_2_with_negative)
 
-    temp_file_path = params.input_preprocessed_data_path or '/tmp'
+    temp_file_path = params.input_preprocessed_data_path or self.logging_dir
+    if not temp_file_path:
+      raise ValueError('You must specify a temporary directory, either in '
+                       'params.input_preprocessed_data_path or logging_dir to '
+                       'store intermediate evaluation TFRecord data.')
     eval_writer = self.squad_lib.FeatureWriter(
         filename=os.path.join(temp_file_path, 'eval.tf_record'),
         is_training=False)
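
The fallback location for the intermediate eval TFRecord changes from a hard-coded '/tmp' to the task's logging_dir, and the task now fails loudly when neither is set. A rough sketch of the resulting precedence; resolve_eval_tmp_dir is a hypothetical helper written only to illustrate the logic above:

def resolve_eval_tmp_dir(input_preprocessed_data_path, logging_dir):
  # Mirrors the precedence in the hunk above: the explicit preprocessed-data
  # path wins, then logging_dir; with neither set, raise instead of silently
  # writing to /tmp.
  temp_file_path = input_preprocessed_data_path or logging_dir
  if not temp_file_path:
    raise ValueError('You must specify a temporary directory, either in '
                     'params.input_preprocessed_data_path or logging_dir.')
  return temp_file_path

# e.g. resolve_eval_tmp_dir('', '/remote/job-123') -> '/remote/job-123'
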
@@ -168,8 +176,7 @@ class QuestionAnsweringTask(base_task.Task):
     if params.is_training:
       input_path = params.input_path
     else:
-      input_path, self._eval_examples, self._eval_features = (
-          self._preprocess_eval_data(params))
+      input_path = self._tf_record_input_path
 
     batch_size = input_context.get_per_replica_batch_size(
         params.global_batch_size) if input_context else params.global_batch_size
...
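
Taken together, the question-answering hunks move eval preprocessing from build_inputs into __init__: the examples, features, and the written TFRecord path are computed once and cached, and build_inputs only reuses the cached path. A self-contained toy sketch of that cache-once, read-many flow; the class and its stub preprocessing step are illustrative, not the real implementation:

class EvalCacheSketch:
  """Toy stand-in for the new eval-data flow in QuestionAnsweringTask."""

  def __init__(self, validation_input_path):
    self._tf_record_input_path = None
    if validation_input_path:
      # In the real task this writes eval.tf_record and keeps the
      # examples/features for post-processing; here we only cache the path.
      self._tf_record_input_path = self._preprocess_eval_data(
          validation_input_path)

  def _preprocess_eval_data(self, raw_path):
    # Stand-in for the expensive SQuAD preprocessing step.
    return raw_path + '.preprocessed.tf_record'

  def build_inputs(self, is_training, train_input_path=None):
    # Training still reads its own files; evaluation reuses the cached path
    # instead of re-running preprocessing on every call.
    return train_input_path if is_training else self._tf_record_input_path
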
@@ -55,8 +55,8 @@ class SentencePredictionConfig(cfg.TaskConfig):
 class SentencePredictionTask(base_task.Task):
   """Task object for sentence_prediction."""
 
-  def __init__(self, params=cfg.TaskConfig):
-    super(SentencePredictionTask, self).__init__(params)
+  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
+    super(SentencePredictionTask, self).__init__(params, logging_dir)
     if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`init_checkpoint` can be specified.')
...
@@ -75,8 +75,8 @@ def _masked_labels_and_weights(y_true):
 class TaggingTask(base_task.Task):
   """Task object for tagging (e.g., NER or POS)."""
 
-  def __init__(self, params=cfg.TaskConfig):
-    super(TaggingTask, self).__init__(params)
+  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
+    super(TaggingTask, self).__init__(params, logging_dir)
     if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`init_checkpoint` can be specified.')
...