Commit 4dc945c0 authored by Will Cromar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339131213
parent 3421f8c6
@@ -32,7 +32,7 @@ def all_strategy_combinations():
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
       ],
       mode='eager',
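To make the swap concrete, here is a minimal, self-contained sketch (not part of the commit; the test class and method names are illustrative) of how a list built by combinations.combine() is consumed: combinations.generate() expands each entry into one parameterized test case and passes the resolved strategy in as the `distribution` argument, which is why renaming the entry to cloud_tpu_strategy is the only change these hunks need.

# Illustrative only: shows how the combinations produced above are consumed.
# `cloud_tpu_strategy` is the name this commit switches to; TPU- and GPU-backed
# cases only run when that hardware is available to the test runner.
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations


class ExampleStrategyTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager'))
  def test_runs_under_strategy(self, distribution):
    # `distribution` is the tf.distribute.Strategy selected for this case.
    self.assertIsInstance(distribution, tf.distribute.Strategy)


if __name__ == '__main__':
  tf.test.main()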
@@ -65,7 +65,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
       combinations.combine(
           distribution_strategy=[
               strategy_combinations.default_strategy,
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
               strategy_combinations.one_device_strategy_gpu,
           ],
           mode='eager',
@@ -17,6 +17,7 @@
 import os
 
 from absl import logging
+from absl.testing import flagsaver
 from absl.testing import parameterized
 from absl.testing.absltest import mock
 import numpy as np
@@ -24,14 +25,18 @@ import tensorflow as tf
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
+from official.nlp.bert import common_flags
 from official.nlp.bert import model_training_utils
 
+common_flags.define_common_bert_flags()
+
 
 def eager_strategy_combinations():
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.mirrored_strategy_with_two_gpus,
@@ -158,6 +163,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
     super(ModelTrainingUtilsTest, self).setUp()
     self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
 
+  @flagsaver.flagsaver
   def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
     input_fn = create_fake_data_input_fn(
         batch_size=8, features_shape=[128], num_classes=3)
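A minimal sketch (the standalone flag name is hypothetical, not one of the BERT flags this test actually defines) of what the @flagsaver.flagsaver decorator added above guarantees: any absl flag values the decorated function changes are restored when it returns, so one test's flag overrides cannot leak into the next.

# Illustrative only; `--example_steps` is a made-up flag, not one of the flags
# defined by common_flags.define_common_bert_flags().
from absl import app
from absl import flags
from absl.testing import flagsaver

FLAGS = flags.FLAGS
flags.DEFINE_integer('example_steps', 1, 'Hypothetical flag for illustration.')


@flagsaver.flagsaver
def mutate_flags():
  # Changes made here are rolled back as soon as the function returns.
  FLAGS.example_steps = 100


def main(_):
  mutate_flags()
  assert FLAGS.example_steps == 1  # restored by flagsaver


if __name__ == '__main__':
  app.run(main)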
@@ -180,7 +186,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_eager_single_step(self, distribution):
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     if isinstance(distribution, tf.distribute.experimental.TPUStrategy):
       with self.assertRaises(ValueError):
         self.run_training(
@@ -191,7 +197,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_gpu_strategy_combinations())
   def test_train_eager_mixed_precision(self, distribution):
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
     tf.keras.mixed_precision.experimental.set_policy(policy)
     self._model_fn = create_model_fn(
@@ -201,7 +207,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_check_artifacts(self, distribution):
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     self.run_training(
         distribution, model_dir, steps_per_loop=10, run_eagerly=False)
@@ -245,7 +251,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_strategy_combinations())
   def test_train_check_callbacks(self, distribution):
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     callback = RecordingCallback()
     callbacks = [callback]
     input_fn = create_fake_data_input_fn(
@@ -296,7 +302,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
             new_callable=mock.PropertyMock, return_value=False), \
         mock.patch.object(extended.__class__, 'should_save_summary',
                           new_callable=mock.PropertyMock, return_value=False):
-      model_dir = self.get_temp_dir()
+      model_dir = self.create_tempdir().full_path
       self.run_training(
           distribution, model_dir, steps_per_loop=10, run_eagerly=False)
       self.assertEmpty(tf.io.gfile.listdir(model_dir))
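The other recurring edit in this file replaces get_temp_dir() with create_tempdir().full_path. A minimal sketch (test name is illustrative) of the pattern the tests move to, assuming absltest-managed temporary directories: create_tempdir() hands the test a managed directory object, and its full_path string can be passed to code that expects a plain filesystem path.

# Illustrative only: demonstrates the create_tempdir().full_path pattern
# used throughout this commit.
import os

import tensorflow as tf


class TempDirExampleTest(tf.test.TestCase):

  def test_writes_into_managed_tempdir(self):
    model_dir = self.create_tempdir().full_path  # absltest-managed directory
    with tf.io.gfile.GFile(os.path.join(model_dir, 'marker.txt'), 'w') as f:
      f.write('ok')
    self.assertLen(tf.io.gfile.listdir(model_dir), 1)


if __name__ == '__main__':
  tf.test.main()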
@@ -61,7 +61,7 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
       combinations.combine(
           distribution=[
               strategy_combinations.default_strategy,
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
           ],
           mode="eager"))
   def test_create_model_with_ds(self, distribution):
@@ -34,7 +34,7 @@ def all_strategy_combinations():
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.mirrored_strategy_with_two_gpus,
@@ -38,7 +38,7 @@ def all_strategy_combinations():
           strategy_combinations.one_device_strategy,
           strategy_combinations.one_device_strategy_gpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
       ],
       mode="eager",
   )
@@ -51,7 +51,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
   @combinations.generate(
       combinations.combine(
           strategy=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
               strategy_combinations.one_device_strategy_gpu,
           ],
           use_sync_bn=[False, True],
@@ -68,7 +68,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
   @combinations.generate(
       combinations.combine(
           strategy=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
               strategy_combinations.one_device_strategy_gpu,
           ],
           use_sync_bn=[False, True],
@@ -126,7 +126,7 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
   @combinations.generate(
       combinations.combine(
           strategy=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
               strategy_combinations.one_device_strategy_gpu,
           ],
           use_sync_bn=[False, True],
@@ -30,7 +30,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
       ],
   )
@@ -92,7 +92,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
   @combinations.generate(
       combinations.combine(
           strategy=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
               strategy_combinations.one_device_strategy_gpu,
           ],
           image_size=[(128, 128),],
@@ -19,7 +19,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import copy
 import functools
 import json
@@ -29,6 +28,7 @@ import sys
 from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple
 
 from absl import flags
+from absl.testing import flagsaver
 from absl.testing import parameterized
 import tensorflow as tf
@@ -36,9 +36,7 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from official.utils.flags import core as flags_core
 from official.vision.image_classification import classifier_trainer
-from official.vision.image_classification import dataset_factory
-from official.vision.image_classification import test_utils
-from official.vision.image_classification.configs import base_configs
 
 classifier_trainer.define_classifier_flags()
@@ -48,7 +46,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
           strategy_combinations.mirrored_strategy_with_two_gpus,
       ],
@@ -99,32 +97,7 @@ def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
   }
 
 
-def get_trivial_model(num_classes: int) -> tf.keras.Model:
-  """Creates and compiles trivial model for ImageNet dataset."""
-  model = test_utils.trivial_model(num_classes=num_classes)
-  lr = 0.01
-  optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
-  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
-  model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
-  return model
-
-
-def get_trivial_data() -> tf.data.Dataset:
-  """Gets trivial data in the ImageNet size."""
-
-  def generate_data(_) -> tf.data.Dataset:
-    image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
-    label = tf.zeros([1], dtype=tf.int32)
-    return image, label
-
-  dataset = tf.data.Dataset.range(1)
-  dataset = dataset.repeat()
-  dataset = dataset.map(
-      generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  dataset = dataset.prefetch(buffer_size=1).batch(1)
-  return dataset
-
-
+@flagsaver.flagsaver
 def run_end_to_end(main: Callable[[Any], None],
                    extra_flags: Optional[Iterable[str]] = None,
                    model_dir: Optional[str] = None):
@@ -153,7 +126,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
     # Some parameters are not defined as flags (e.g. cannot run
     # classifier_train.py --batch_size=...) by design, so use
     # "--params_override=..." instead
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     base_flags = [
         '--data_dir=not_used',
         '--model_type=' + model,
@@ -187,7 +160,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
     # Some parameters are not defined as flags (e.g. cannot run
     # classifier_train.py --batch_size=...) by design, so use
     # "--params_override=..." instead
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     base_flags = [
         '--data_dir=not_used',
         '--model_type=' + model,
@@ -217,7 +190,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           distribution=[
-              strategy_combinations.tpu_strategy,
+              strategy_combinations.cloud_tpu_strategy,
           ],
           model=[
              'efficientnet',
@@ -232,7 +205,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
     # Some parameters are not defined as flags (e.g. cannot run
     # classifier_train.py --batch_size=...) by design, so use
     # "--params_override=..." instead
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     base_flags = [
         '--data_dir=not_used',
         '--model_type=' + model,
@@ -251,7 +224,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(distribution_strategy_combinations())
   def test_end_to_end_invalid_mode(self, distribution, model, dataset):
     """Test the Keras EfficientNet model with `strategy`."""
-    model_dir = self.get_temp_dir()
+    model_dir = self.create_tempdir().full_path
     extra_flags = [
         '--data_dir=not_used',
         '--mode=invalid_mode',
@@ -266,111 +239,5 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
     run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
 
 
-class UtilTests(parameterized.TestCase, tf.test.TestCase):
-  """Tests for individual utility functions within classifier_trainer.py."""
-
-  @parameterized.named_parameters(
-      ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
-      ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
-      ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
-      ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
-      ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
-      ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
-      ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
-      ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
-      ('resnet', 'resnet', '', None),
-  )
-  def test_get_model_size(self, model, model_name, expected):
-    config = base_configs.ExperimentConfig(
-        model_name=model,
-        model=base_configs.ModelConfig(
-            model_params={
-                'model_name': model_name,
-            },))
-    size = classifier_trainer.get_image_size_from_model(config)
-    self.assertEqual(size, expected)
-
-  @parameterized.named_parameters(
-      ('dynamic', 'dynamic', None, 'dynamic'),
-      ('scalar', 128., None, 128.),
-      ('float32', None, 'float32', 1),
-      ('float16', None, 'float16', 128),
-  )
-  def test_get_loss_scale(self, loss_scale, dtype, expected):
-    config = base_configs.ExperimentConfig(
-        runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
-        train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
-    ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
-    self.assertEqual(ls, expected)
-
-  @parameterized.named_parameters(('float16', 'float16'),
-                                  ('bfloat16', 'bfloat16'))
-  def test_initialize(self, dtype):
-    config = base_configs.ExperimentConfig(
-        runtime=base_configs.RuntimeConfig(
-            run_eagerly=False,
-            enable_xla=False,
-            per_gpu_thread_count=1,
-            gpu_thread_mode='gpu_private',
-            num_gpus=1,
-            dataset_num_private_threads=1,
-        ),
-        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
-        model=base_configs.ModelConfig(),
-    )
-
-    class EmptyClass:
-      pass
-
-    fake_ds_builder = EmptyClass()
-    fake_ds_builder.dtype = dtype
-    fake_ds_builder.config = EmptyClass()
-    classifier_trainer.initialize(config, fake_ds_builder)
-
-  def test_resume_from_checkpoint(self):
-    """Tests functionality for resuming from checkpoint."""
-    # Set the keras policy
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
-
-    # Get the model, datasets, and compile it.
-    model = get_trivial_model(10)
-
-    # Create the checkpoint
-    model_dir = self.get_temp_dir()
-    train_epochs = 1
-    train_steps = 10
-    ds = get_trivial_data()
-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(
-            os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
-            save_weights_only=True)
-    ]
-    model.fit(
-        ds,
-        callbacks=callbacks,
-        epochs=train_epochs,
-        steps_per_epoch=train_steps)
-
-    # Test load from checkpoint
-    clean_model = get_trivial_model(10)
-    weights_before_load = copy.deepcopy(clean_model.get_weights())
-    initial_epoch = classifier_trainer.resume_from_checkpoint(
-        model=clean_model, model_dir=model_dir, train_steps=train_steps)
-    self.assertEqual(initial_epoch, 1)
-    self.assertNotAllClose(weights_before_load, clean_model.get_weights())
-    tf.io.gfile.rmtree(model_dir)
-
-  def test_serialize_config(self):
-    """Tests functionality for serializing data."""
-    config = base_configs.ExperimentConfig()
-    model_dir = self.get_temp_dir()
-    classifier_trainer.serialize_config(params=config, model_dir=model_dir)
-    saved_params_path = os.path.join(model_dir, 'params.yaml')
-    self.assertTrue(os.path.exists(saved_params_path))
-    tf.io.gfile.rmtree(model_dir)
-
-
 if __name__ == '__main__':
   tf.test.main()
# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the classifier trainer models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import tensorflow as tf

from official.vision.image_classification import classifier_trainer
from official.vision.image_classification import dataset_factory
from official.vision.image_classification import test_utils
from official.vision.image_classification.configs import base_configs


def get_trivial_model(num_classes: int) -> tf.keras.Model:
  """Creates and compiles trivial model for ImageNet dataset."""
  model = test_utils.trivial_model(num_classes=num_classes)
  lr = 0.01
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
  model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
  return model


def get_trivial_data() -> tf.data.Dataset:
  """Gets trivial data in the ImageNet size."""

  def generate_data(_) -> tf.data.Dataset:
    image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return image, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.prefetch(buffer_size=1).batch(1)
  return dataset


class UtilTests(parameterized.TestCase, tf.test.TestCase):
  """Tests for individual utility functions within classifier_trainer.py."""

  @parameterized.named_parameters(
      ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
      ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
      ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
      ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
      ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
      ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
      ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
      ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
      ('resnet', 'resnet', '', None),
  )
  def test_get_model_size(self, model, model_name, expected):
    config = base_configs.ExperimentConfig(
        model_name=model,
        model=base_configs.ModelConfig(
            model_params={
                'model_name': model_name,
            },))
    size = classifier_trainer.get_image_size_from_model(config)
    self.assertEqual(size, expected)

  @parameterized.named_parameters(
      ('dynamic', 'dynamic', None, 'dynamic'),
      ('scalar', 128., None, 128.),
      ('float32', None, 'float32', 1),
      ('float16', None, 'float16', 128),
  )
  def test_get_loss_scale(self, loss_scale, dtype, expected):
    config = base_configs.ExperimentConfig(
        runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
    ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
    self.assertEqual(ls, expected)

  @parameterized.named_parameters(('float16', 'float16'),
                                  ('bfloat16', 'bfloat16'))
  def test_initialize(self, dtype):
    config = base_configs.ExperimentConfig(
        runtime=base_configs.RuntimeConfig(
            run_eagerly=False,
            enable_xla=False,
            per_gpu_thread_count=1,
            gpu_thread_mode='gpu_private',
            num_gpus=1,
            dataset_num_private_threads=1,
        ),
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig(),
    )

    class EmptyClass:
      pass

    fake_ds_builder = EmptyClass()
    fake_ds_builder.dtype = dtype
    fake_ds_builder.config = EmptyClass()
    classifier_trainer.initialize(config, fake_ds_builder)

  def test_resume_from_checkpoint(self):
    """Tests functionality for resuming from checkpoint."""
    # Set the keras policy
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

    # Get the model, datasets, and compile it.
    model = get_trivial_model(10)

    # Create the checkpoint
    model_dir = self.create_tempdir().full_path
    train_epochs = 1
    train_steps = 10
    ds = get_trivial_data()
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
            save_weights_only=True)
    ]
    model.fit(
        ds,
        callbacks=callbacks,
        epochs=train_epochs,
        steps_per_epoch=train_steps)

    # Test load from checkpoint
    clean_model = get_trivial_model(10)
    weights_before_load = copy.deepcopy(clean_model.get_weights())
    initial_epoch = classifier_trainer.resume_from_checkpoint(
        model=clean_model, model_dir=model_dir, train_steps=train_steps)
    self.assertEqual(initial_epoch, 1)
    self.assertNotAllClose(weights_before_load, clean_model.get_weights())
    tf.io.gfile.rmtree(model_dir)

  def test_serialize_config(self):
    """Tests functionality for serializing data."""
    config = base_configs.ExperimentConfig()
    model_dir = self.create_tempdir().full_path
    classifier_trainer.serialize_config(params=config, model_dir=model_dir)
    saved_params_path = os.path.join(model_dir, 'params.yaml')
    self.assertTrue(os.path.exists(saved_params_path))
    tf.io.gfile.rmtree(model_dir)


if __name__ == '__main__':
  tf.test.main()
@@ -29,11 +29,14 @@ from official.utils.testing import integration
 from official.vision.image_classification import mnist_main
 
+mnist_main.define_mnist_flags()
+
 
 def eager_strategy_combinations():
   return combinations.combine(
       distribution=[
           strategy_combinations.default_strategy,
-          strategy_combinations.tpu_strategy,
+          strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
       ],
       mode="eager",
@@ -47,7 +50,6 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasMnistTest, cls).setUpClass()
-    mnist_main.define_mnist_flags()
 
   def tearDown(self):
     super(KerasMnistTest, self).tearDown()
@@ -81,7 +83,7 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
     integration.run_synthetic(
         main=run,
         synth=False,
-        tmp_root=self.get_temp_dir(),
+        tmp_root=self.create_tempdir().full_path,
         extra_flags=extra_flags)
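The MNIST test also moves mnist_main.define_mnist_flags() from setUpClass to module import time. A minimal sketch (the module helper and flag name below are hypothetical stand-ins, not the real MNIST flags) of that pattern, assuming absl flags: defining the flags once when the test module is imported means they are already registered when the test runner parses argv, and flagsaver-style overrides can then target them per test.

# Illustrative only: `define_my_flags` and `--train_epochs` are stand-ins for
# a script's real define_*_flags() helper and its flags.
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf

FLAGS = flags.FLAGS


def define_my_flags():
  flags.DEFINE_integer('train_epochs', 1, 'Hypothetical flag.')


# Defined once at import time, not inside setUpClass.
define_my_flags()


class FlagDefinitionExampleTest(tf.test.TestCase):

  @flagsaver.flagsaver(train_epochs=2)
  def test_override_is_scoped_to_this_test(self):
    self.assertEqual(FLAGS.train_epochs, 2)


if __name__ == '__main__':
  tf.test.main()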