Commit f276d472 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304242451
parent a3be7365
......@@ -257,7 +257,6 @@ class RuntimeConfig(Config):
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_eager: Whether or not to enable eager mode.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_threads_enabled: Whether or not GPU threads are enabled.
......@@ -272,9 +271,12 @@ class RuntimeConfig(Config):
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
"""
distribution_strategy: str = 'mirrored'
enable_eager: bool = False
enable_xla: bool = False
gpu_threads_enabled: bool = False
gpu_thread_mode: Optional[str] = None
......@@ -286,6 +288,8 @@ class RuntimeConfig(Config):
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[str] = None
run_eagerly: bool = False
@dataclasses.dataclass
......@@ -312,7 +316,10 @@ class CallbacksConfig(Config):
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
enable_time_history: bool = True
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -22,15 +23,20 @@ import os
from absl import logging
import tensorflow as tf
from typing import Any, List, MutableMapping, Text
from typing import Any, List, MutableMapping
from official.utils.misc import keras_utils
def get_callbacks(model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
initial_step: int = 0,
model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
batch_size: int = 0,
log_steps: int = 0,
model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
......@@ -44,6 +50,11 @@ def get_callbacks(model_checkpoint: bool = True,
track_lr=track_lr,
initial_step=initial_step,
write_images=write_model_weights))
if time_history:
callbacks.append(keras_utils.TimeHistory(
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
return callbacks
......@@ -74,7 +85,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
# classification loss
def __init__(self,
log_dir: Text,
log_dir: str,
track_lr: bool = False,
initial_step: int = 0,
**kwargs):
......@@ -84,7 +95,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_batch_begin(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
self.step += 1
if logs is None:
logs = {}
......@@ -93,7 +104,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_begin(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
......@@ -104,14 +115,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def on_epoch_end(self,
epoch: int,
logs: MutableMapping[Text, Any] = None) -> None:
logs: MutableMapping[str, Any] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
def _calculate_metrics(self) -> MutableMapping[Text, Any]:
def _calculate_metrics(self) -> MutableMapping[str, Any]:
logs = {}
if self._track_lr:
logs['learning_rate'] = self._calculate_lr()
......
......@@ -44,10 +44,24 @@ from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
MODELS = {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
"""Returns the mapping from dtype string representations to TF dtypes."""
return {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
......@@ -120,7 +134,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
def get_loss_scale(params: base_configs.ExperimentConfig,
fp16_default: float = 128.) -> float:
"""Returns the loss scale for initializations."""
loss_scale = params.model.loss.loss_scale
loss_scale = params.runtime.loss_scale
if loss_scale == 'dynamic':
return loss_scale
elif loss_scale is not None:
......@@ -145,7 +159,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'name': model,
},
'runtime': {
'enable_eager': flags_obj.enable_eager,
'run_eagerly': flags_obj.run_eagerly,
'tpu': flags_obj.tpu,
},
'train_dataset': {
......@@ -154,6 +168,11 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
'train': {
'time_history': {
'log_steps': flags_obj.log_steps,
},
},
}
overriding_configs = (flags_obj.config_file,
......@@ -209,10 +228,11 @@ def resume_from_checkpoint(model: tf.keras.Model,
return int(initial_epoch)
def initialize(params: base_configs.ExperimentConfig):
def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(
enable_eager=params.runtime.enable_eager,
enable_eager=params.runtime.run_eagerly,
enable_xla=params.runtime.enable_xla)
if params.runtime.gpu_threads_enabled:
keras_utils.set_gpu_thread_mode_and_count(
......@@ -221,12 +241,11 @@ def initialize(params: base_configs.ExperimentConfig):
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
dataset = params.train_dataset or params.validation_dataset
performance.set_mixed_precision_policy(dataset.dtype)
performance.set_mixed_precision_policy(dataset_builder.dtype)
if dataset.data_format:
data_format = dataset.data_format
elif tf.config.list_physical_devices('GPU'):
if dataset_builder.config.data_format:
data_format = dataset_builder.config.data_format
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
data_format = 'channels_last'
......@@ -234,7 +253,7 @@ def initialize(params: base_configs.ExperimentConfig):
distribution_utils.configure_cluster(
params.runtime.worker_hosts,
params.runtime.task_index)
if params.runtime.enable_eager:
if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True)
......@@ -251,7 +270,7 @@ def define_classifier_flags():
default=None,
help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
flags.DEFINE_bool(
'enable_eager',
'run_eagerly',
default=None,
help='Use eager execution and disable autograph for debugging.')
flags.DEFINE_string(
......@@ -262,6 +281,10 @@ def define_classifier_flags():
'dataset',
default=None,
help='The name of the dataset, e.g. ImageNet, etc.')
flags.DEFINE_integer(
'log_steps',
default=100,
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig,
......@@ -304,11 +327,13 @@ def train_and_eval(
train_steps = params.train.steps or train_builder.num_steps
validation_steps = params.evaluation.steps or validation_builder.num_steps
initialize(params, train_builder)
logging.info('Global batch size: %d', train_builder.global_batch_size)
with strategy_scope:
model_params = params.model.model_params.as_dict()
model = MODELS[params.model.name](**model_params)
model = get_models()[params.model.name](**model_params)
learning_rate = optimizer_factory.build_learning_rate(
params=params.model.learning_rate,
batch_size=train_builder.global_batch_size,
......@@ -328,8 +353,7 @@ def train_and_eval(
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
run_eagerly=params.runtime.enable_eager)
metrics=metrics)
initial_epoch = 0
if params.train.resume_checkpoint:
......@@ -342,26 +366,37 @@ def train_and_eval(
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
time_history=params.train.callbacks.enable_time_history,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir)
if params.evaluation.skip_eval:
validation_kwargs = {}
else:
validation_kwargs = {
'validation_data': validation_dataset,
'validation_steps': validation_steps,
'validation_freq': params.evaluation.epochs_between_evals,
}
history = model.fit(
train_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
validation_data=validation_dataset,
validation_steps=validation_steps,
validation_freq=params.evaluation.epochs_between_evals)
**validation_kwargs)
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
validation_output = None
if not params.evaluation.skip_eval:
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history,
validation_output,
callbacks)
......@@ -372,7 +407,7 @@ def export(params: base_configs.ExperimentConfig):
"""Runs the model export functionality."""
logging.info('Exporting model.')
model_params = params.model.model_params.as_dict()
model = MODELS[params.model.name](**model_params)
model = get_models()[params.model.name](**model_params)
checkpoint = params.export.checkpoint
if checkpoint is None:
logging.info('No export checkpoint was provided. Using the latest '
......@@ -395,8 +430,6 @@ def run(flags_obj: flags.FlagValues,
Dictionary of training/eval stats
"""
params = _get_params_from_flags(flags_obj)
initialize(params)
if params.mode == 'train_and_eval':
return train_and_eval(params, strategy_override)
elif params.mode == 'export_only':
......
......@@ -233,8 +233,8 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
model=base_configs.ModelConfig(
loss=base_configs.LossConfig(loss_scale=loss_scale)),
runtime=base_configs.RuntimeConfig(
loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
......@@ -246,7 +246,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
enable_eager=False,
run_eagerly=False,
enable_xla=False,
gpu_threads_enabled=True,
per_gpu_thread_count=1,
......@@ -258,7 +258,14 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
model=base_configs.ModelConfig(
loss=base_configs.LossConfig(loss_scale='dynamic')),
)
classifier_trainer.initialize(config)
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
fake_ds_builder.config.data_format = None
classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
......
......@@ -58,6 +58,17 @@ class MetricsConfig(base_config.Config):
top_5: bool = None
@dataclasses.dataclass
class TimeHistoryConfig(base_config.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: int = None
@dataclasses.dataclass
class TrainConfig(base_config.Config):
"""Configuration for training.
......@@ -77,8 +88,9 @@ class TrainConfig(base_config.Config):
epochs: int = None
steps: int = None
callbacks: CallbacksConfig = CallbacksConfig()
metrics: List[str] = None
metrics: MetricsConfig = None
tensorboard: TensorboardConfig = TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
@dataclasses.dataclass
......@@ -91,10 +103,12 @@ class EvalConfig(base_config.Config):
steps: The number of eval steps to run during evaluation. If None, this will
be inferred based on the number of images and batch size. Defaults to
None.
skip_eval: Whether or not to skip evaluation.
"""
epochs_between_evals: int = None
steps: int = None
skip_eval: bool = False
@dataclasses.dataclass
......
......@@ -52,6 +52,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False))
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
......@@ -83,6 +84,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False))
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
......
......@@ -203,6 +203,30 @@ class DatasetBuilder:
# Always divide by the global batch size to get the correct # of steps
return self.num_examples // self.global_batch_size
@property
def dtype(self) -> tf.dtypes.DType:
"""Converts the config's dtype string to a tf dtype.
Returns:
A mapping from string representation of a dtype to the `tf.dtypes.DType`.
Raises:
ValueError if the config's dtype is not supported.
"""
dtype_map = {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
try:
return dtype_map[self.config.dtype]
except:
raise ValueError('Invalid DType provided. Supported types: {}'.format(
dtype_map.keys()))
@property
def image_size(self) -> int:
"""The size of each image (can be inferred from the dataset)."""
......@@ -326,7 +350,7 @@ class DatasetBuilder:
def generate_data(_):
image = tf.zeros([self.image_size, self.image_size, self.num_channels],
dtype=self.config.dtype)
dtype=self.dtype)
label = tf.zeros([1], dtype=tf.int32)
return image, label
......@@ -451,7 +475,7 @@ class DatasetBuilder:
image_size=self.image_size,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.config.dtype,
dtype=self.dtype,
augmenter=self.augmenter)
else:
image = preprocessing.preprocess_for_eval(
......@@ -460,7 +484,7 @@ class DatasetBuilder:
num_channels=self.num_channels,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.config.dtype)
dtype=self.dtype)
label = tf.cast(label, tf.int32)
if self.config.one_hot:
......
......@@ -166,7 +166,6 @@ def build_stats(history, eval_output, callbacks):
if eval_output:
stats['accuracy_top_1'] = float(eval_output[1])
stats['eval_loss'] = float(eval_output[0])
if history and history.history:
train_hist = history.history
# Gets final loss from training.
......@@ -176,6 +175,8 @@ def build_stats(history, eval_output, callbacks):
stats[TRAIN_TOP_1] = float(train_hist['categorical_accuracy'][-1])
elif 'sparse_categorical_accuracy' in train_hist:
stats[TRAIN_TOP_1] = float(train_hist['sparse_categorical_accuracy'][-1])
elif 'accuracy' in train_hist:
stats[TRAIN_TOP_1] = float(train_hist['accuracy'][-1])
if not callbacks:
return stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment