Commit 646c5755 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 337927475
parent b18f2868
......@@ -15,8 +15,7 @@
# ==============================================================================
"""Defines the base task abstraction."""
import abc
import functools
from typing import Any, Callable, Optional
from typing import Optional
from absl import logging
import tensorflow as tf
......@@ -25,9 +24,9 @@ import tensorflow as tf
class Task(tf.Module, metaclass=abc.ABCMeta):
"""A single-replica view of training procedure.
Tasks provide artifacts for training/evalution procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss
and customized metrics with reduction.
Tasks provide artifacts for training/validation procedures, including
loading/iterating over Datasets, training/validation steps, calculating the
loss and customized metrics with reduction.
"""
# Special keys in train/validate step returned logs.
......@@ -91,41 +90,6 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
A model instance.
"""
def compile_model(self,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
loss=None,
train_step: Optional[Callable[..., Any]] = None,
validation_step: Optional[Callable[..., Any]] = None,
**kwargs) -> tf.keras.Model:
"""Compiles the model with objects created by the task.
The method should not be used in any customized training implementation.
Args:
model: a keras.Model.
optimizer: the keras optimizer.
loss: a callable/list of losses.
train_step: optional train step function defined by the task.
validation_step: optional validation_step step function defined by the
task.
**kwargs: other kwargs consumed by keras.Model compile().
Returns:
a compiled keras.Model.
"""
if bool(loss is None) == bool(train_step is None):
raise ValueError("`loss` and `train_step` should be exclusive to "
"each other.")
model.compile(optimizer=optimizer, loss=loss, **kwargs)
if train_step:
model.train_step = functools.partial(
train_step, model=model, optimizer=model.optimizer)
if validation_step:
model.test_step = functools.partial(validation_step, model=model)
return model
@abc.abstractmethod
def build_inputs(self,
params,
......@@ -244,9 +208,9 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
......@@ -273,9 +237,9 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.core.base_task."""
import functools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode='eager',
)
class TaskKerasTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_task_with_step_override(self, distribution):
with distribution.scope():
task = mock_task.MockTask()
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
metrics=task.build_metrics(),
train_step=task.train_step,
validation_step=task.validation_step)
dataset = task.build_inputs(params=None)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
# Without specifying metrics through compile.
with distribution.scope():
train_metrics = task.build_metrics(training=True)
val_metrics = task.build_metrics(training=False)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
train_step=functools.partial(task.train_step, metrics=train_metrics),
validation_step=functools.partial(
task.validation_step, metrics=val_metrics))
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
def test_task_with_fit(self):
task = mock_task.MockTask()
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=task.build_metrics())
dataset = task.build_inputs(params=None)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn('loss', logs.history)
self.assertIn('acc', logs.history)
self.assertLen(model.evaluate(dataset, steps=1), 2)
def test_task_invalid_compile(self):
task = mock_task.MockTask()
model = task.build_model()
with self.assertRaises(ValueError):
_ = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=task.build_metrics(),
train_step=task.train_step)
if __name__ == '__main__':
tf.test.main()
......@@ -78,8 +78,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
with distribution.scope():
trainer = self.create_test_trainer(self._config)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['acc'], 5. * distribution.num_replicas_in_sync)
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
......
......@@ -131,24 +131,6 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
version_2_with_negative))
self._run_task(config)
def test_task_with_fit(self):
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=self._get_validation_data_config())
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(lr=0.1),
train_step=task.train_step,
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
dataset = task.build_inputs(config.train_data)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn("loss", logs.history)
self.assertIn("start_positions_accuracy", logs.history)
self.assertIn("end_positions_accuracy", logs.history)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
......
......@@ -210,20 +210,6 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
outputs = self._run_task(config)
self.assertEqual(outputs["sentence_prediction"].shape.as_list(), [8, 1])
def test_task_with_fit(self):
config = sentence_prediction.SentencePredictionConfig(
model=self.get_model_config(2), train_data=self._train_data_config)
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(lr=0.1),
train_step=task.train_step,
metrics=task.build_metrics())
dataset = task.build_inputs(config.train_data)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn("loss", logs.history)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
......
......@@ -96,24 +96,6 @@ class TaggingTest(tf.test.TestCase):
task.validation_step(next(iterator), model, metrics=metrics)
task.initialize(model)
def test_task_with_fit(self):
config = tagging.TaggingConfig(
model=tagging.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
class_names=["O", "B-PER", "I-PER"])
task = tagging.TaggingTask(config)
model = task.build_model()
model = task.compile_model(
model,
optimizer=tf.keras.optimizers.SGD(lr=0.1),
train_step=task.train_step,
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])
dataset = task.build_inputs(config.train_data)
logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
self.assertIn("loss", logs.history)
self.assertIn("accuracy", logs.history)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
......
......@@ -51,7 +51,9 @@ class MockTask(base_task.Task):
def build_model(self, *arg, **kwargs):
inputs = tf.keras.layers.Input(shape=(2,), name="random", dtype=tf.float32)
outputs = tf.keras.layers.Dense(1)(inputs)
outputs = tf.keras.layers.Dense(
1, bias_initializer=tf.keras.initializers.Ones())(
inputs)
network = tf.keras.Model(inputs=inputs, outputs=outputs)
return MockModel(network)
......@@ -59,6 +61,11 @@ class MockTask(base_task.Task):
del training
return [tf.keras.metrics.Accuracy(name="acc")]
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
logs = super().validation_step(inputs, model, metrics)
logs["counter"] = tf.ones((1,), dtype=tf.float32)
return logs
def build_inputs(self, params):
def generate_data(_):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment