Commit f276d472 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 304242451
parent a3be7365
@@ -257,7 +257,6 @@ class RuntimeConfig(Config):
   Attributes:
     distribution_strategy: e.g. 'mirrored', 'tpu', etc.
-    enable_eager: Whether or not to enable eager mode.
     enable_xla: Whether or not to enable XLA.
     per_gpu_thread_count: thread count per GPU.
     gpu_threads_enabled: Whether or not GPU threads are enabled.
@@ -272,9 +271,12 @@ class RuntimeConfig(Config):
     all_reduce_alg: Defines the algorithm for performing all-reduce.
     num_packs: Sets `num_packs` in the cross device ops used in
       MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+    loss_scale: The type of loss scale. This is used when setting the mixed
+      precision policy.
+    run_eagerly: Whether or not to run the experiment eagerly.
   """
   distribution_strategy: str = 'mirrored'
-  enable_eager: bool = False
   enable_xla: bool = False
   gpu_threads_enabled: bool = False
   gpu_thread_mode: Optional[str] = None
@@ -286,6 +288,8 @@ class RuntimeConfig(Config):
   task_index: int = -1
   all_reduce_alg: Optional[str] = None
   num_packs: int = 1
+  loss_scale: Optional[str] = None
+  run_eagerly: bool = False


 @dataclasses.dataclass
@@ -312,7 +316,10 @@ class CallbacksConfig(Config):
       Callback. Defaults to True.
     enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
       Defaults to True.
+    enable_time_history: Whether or not to enable TimeHistory Callbacks.
+      Defaults to True.
   """
   enable_checkpoint_and_export: bool = True
   enable_tensorboard: bool = True
+  enable_time_history: bool = True
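Taken together, these config changes move loss scaling and eager execution onto the runtime config and make the TimeHistory callback toggleable. A minimal, hedged sketch of setting the new fields (field names come from this diff; the constructors are the dataclasses shown above):

# Hypothetical usage of the fields added in this commit.
runtime = RuntimeConfig(
    distribution_strategy='mirrored',
    loss_scale='dynamic',  # consumed by get_loss_scale() for mixed precision
    run_eagerly=False)     # replaces the removed enable_eager field
callbacks_config = CallbacksConfig(enable_time_history=True)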
+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,15 +23,20 @@ import os
 from absl import logging
 import tensorflow as tf
-from typing import Any, List, MutableMapping, Text
+from typing import Any, List, MutableMapping
+
+from official.utils.misc import keras_utils


 def get_callbacks(model_checkpoint: bool = True,
                   include_tensorboard: bool = True,
+                  time_history: bool = True,
                   track_lr: bool = True,
                   write_model_weights: bool = True,
                   initial_step: int = 0,
-                  model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
+                  batch_size: int = 0,
+                  log_steps: int = 0,
+                  model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
   """Get all callbacks."""
   model_dir = model_dir or ''
   callbacks = []
@@ -44,6 +50,11 @@ def get_callbacks(model_checkpoint: bool = True,
         track_lr=track_lr,
         initial_step=initial_step,
         write_images=write_model_weights))
+  if time_history:
+    callbacks.append(keras_utils.TimeHistory(
+        batch_size,
+        log_steps,
+        logdir=model_dir if include_tensorboard else None))
   return callbacks
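With the new parameters, callers pass the global batch size and logging interval through to TimeHistory, which uses them to report batch-level throughput. A sketch of a call matching the new signature (argument values here are illustrative, not from the source):

# Illustrative call; batch_size and log_steps feed keras_utils.TimeHistory.
callbacks = get_callbacks(
    model_checkpoint=True,
    include_tensorboard=True,
    time_history=True,
    batch_size=256,        # global batch size
    log_steps=100,         # interval between batch-level stat logs
    model_dir='/tmp/model_dir')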
@@ -74,7 +85,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   # classification loss

   def __init__(self,
-               log_dir: Text,
+               log_dir: str,
                track_lr: bool = False,
                initial_step: int = 0,
                **kwargs):
@@ -84,7 +95,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_batch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     self.step += 1
     if logs is None:
       logs = {}
@@ -93,7 +104,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_epoch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()
@@ -104,14 +115,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_epoch_end(self,
                    epoch: int,
-                   logs: MutableMapping[Text, Any] = None) -> None:
+                   logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()
     logs.update(metrics)
     super(CustomTensorBoard, self).on_epoch_end(epoch, logs)

-  def _calculate_metrics(self) -> MutableMapping[Text, Any]:
+  def _calculate_metrics(self) -> MutableMapping[str, Any]:
     logs = {}
     if self._track_lr:
       logs['learning_rate'] = self._calculate_lr()
...
@@ -44,10 +44,24 @@ from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_model

-MODELS = {
+
+def get_models() -> Mapping[str, tf.keras.Model]:
+  """Returns the mapping from model type name to Keras model."""
+  return {
     'efficientnet': efficientnet_model.EfficientNet.from_name,
     'resnet': resnet_model.resnet50,
 }
+
+
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+  """Returns the mapping from dtype string representations to TF dtypes."""
+  return {
+      'float32': tf.float32,
+      'bfloat16': tf.bfloat16,
+      'float16': tf.float16,
+      'fp32': tf.float32,
+      'bf16': tf.bfloat16,
+  }


 def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
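Replacing the module-level MODELS dict with accessor functions defers building the mapping until call time; usage stays a simple lookup. A short sketch (keys come from this diff; the constructor kwargs are hypothetical):

# Look up the model constructor and dtype at call time.
model_fn = get_models()['resnet']     # resnet_model.resnet50
model = model_fn(num_classes=1000)    # hypothetical kwargs for illustration
dtype = get_dtype_map()['bf16']       # tf.bfloat16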
@@ -120,7 +134,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
 def get_loss_scale(params: base_configs.ExperimentConfig,
                    fp16_default: float = 128.) -> float:
   """Returns the loss scale for initializations."""
-  loss_scale = params.model.loss.loss_scale
+  loss_scale = params.runtime.loss_scale
   if loss_scale == 'dynamic':
     return loss_scale
   elif loss_scale is not None:
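The rest of get_loss_scale is elided by the diff view; based on the visible branches, the resolution order appears to be: 'dynamic' passes through, an explicit value wins, and otherwise the dataset dtype decides. A hedged standalone sketch of that logic, not the verbatim function body:

def resolve_loss_scale(loss_scale, dtype, fp16_default=128.):
  """Sketch of the resolution order implied by the visible lines."""
  if loss_scale == 'dynamic':
    return loss_scale             # dynamic loss scaling
  if loss_scale is not None:
    return float(loss_scale)      # explicit user-provided scale
  # Assumed default branch: fp16 gets the default scale, others need none.
  return fp16_default if dtype == 'float16' else 1.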
@@ -145,7 +159,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
           'name': model,
       },
       'runtime': {
-          'enable_eager': flags_obj.enable_eager,
+          'run_eagerly': flags_obj.run_eagerly,
           'tpu': flags_obj.tpu,
       },
       'train_dataset': {
@@ -154,6 +168,11 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
       'validation_dataset': {
           'data_dir': flags_obj.data_dir,
       },
+      'train': {
+          'time_history': {
+              'log_steps': flags_obj.log_steps,
+          },
+      },
   }

   overriding_configs = (flags_obj.config_file,
@@ -209,10 +228,11 @@ def resume_from_checkpoint(model: tf.keras.Model,
   return int(initial_epoch)


-def initialize(params: base_configs.ExperimentConfig):
+def initialize(params: base_configs.ExperimentConfig,
+               dataset_builder: dataset_factory.DatasetBuilder):
   """Initializes backend related initializations."""
   keras_utils.set_session_config(
-      enable_eager=params.runtime.enable_eager,
+      enable_eager=params.runtime.run_eagerly,
       enable_xla=params.runtime.enable_xla)
   if params.runtime.gpu_threads_enabled:
     keras_utils.set_gpu_thread_mode_and_count(
@@ -221,12 +241,11 @@ def initialize(params: base_configs.ExperimentConfig,
         num_gpus=params.runtime.num_gpus,
         datasets_num_private_threads=params.runtime.dataset_num_private_threads)

-  dataset = params.train_dataset or params.validation_dataset
-  performance.set_mixed_precision_policy(dataset.dtype)
+  performance.set_mixed_precision_policy(dataset_builder.dtype)

-  if dataset.data_format:
-    data_format = dataset.data_format
-  elif tf.config.list_physical_devices('GPU'):
+  if dataset_builder.config.data_format:
+    data_format = dataset_builder.config.data_format
+  if tf.config.list_physical_devices('GPU'):
     data_format = 'channels_first'
   else:
     data_format = 'channels_last'
@@ -234,7 +253,7 @@ def initialize(params: base_configs.ExperimentConfig,
   distribution_utils.configure_cluster(
       params.runtime.worker_hosts,
       params.runtime.task_index)
-  if params.runtime.enable_eager:
+  if params.runtime.run_eagerly:
     # Enable eager execution to allow step-by-step debugging
     tf.config.experimental_run_functions_eagerly(True)
@@ -251,7 +270,7 @@ def define_classifier_flags():
       default=None,
       help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
   flags.DEFINE_bool(
-      'enable_eager',
+      'run_eagerly',
       default=None,
       help='Use eager execution and disable autograph for debugging.')
   flags.DEFINE_string(
@@ -262,6 +281,10 @@ def define_classifier_flags():
       'dataset',
       default=None,
       help='The name of the dataset, e.g. ImageNet, etc.')
+  flags.DEFINE_integer(
+      'log_steps',
+      default=100,
+      help='The interval of steps between logging of batch level stats.')


 def serialize_config(params: base_configs.ExperimentConfig,
@@ -304,11 +327,13 @@ def train_and_eval(
   train_steps = params.train.steps or train_builder.num_steps
   validation_steps = params.evaluation.steps or validation_builder.num_steps

+  initialize(params, train_builder)
+
   logging.info('Global batch size: %d', train_builder.global_batch_size)

   with strategy_scope:
     model_params = params.model.model_params.as_dict()
-    model = MODELS[params.model.name](**model_params)
+    model = get_models()[params.model.name](**model_params)
     learning_rate = optimizer_factory.build_learning_rate(
         params=params.model.learning_rate,
         batch_size=train_builder.global_batch_size,
@@ -328,8 +353,7 @@ def train_and_eval(
     loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
     model.compile(optimizer=optimizer,
                   loss=loss_obj,
-                  metrics=metrics,
-                  run_eagerly=params.runtime.enable_eager)
+                  metrics=metrics)

     initial_epoch = 0
     if params.train.resume_checkpoint:
@@ -342,26 +366,37 @@ def train_and_eval(
   callbacks = custom_callbacks.get_callbacks(
       model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
       include_tensorboard=params.train.callbacks.enable_tensorboard,
+      time_history=params.train.callbacks.enable_time_history,
       track_lr=params.train.tensorboard.track_lr,
       write_model_weights=params.train.tensorboard.write_model_weights,
       initial_step=initial_epoch * train_steps,
+      batch_size=train_builder.global_batch_size,
+      log_steps=params.train.time_history.log_steps,
       model_dir=params.model_dir)

+  if params.evaluation.skip_eval:
+    validation_kwargs = {}
+  else:
+    validation_kwargs = {
+        'validation_data': validation_dataset,
+        'validation_steps': validation_steps,
+        'validation_freq': params.evaluation.epochs_between_evals,
+    }
+
   history = model.fit(
       train_dataset,
       epochs=train_epochs,
       steps_per_epoch=train_steps,
       initial_epoch=initial_epoch,
       callbacks=callbacks,
-      validation_data=validation_dataset,
-      validation_steps=validation_steps,
-      validation_freq=params.evaluation.epochs_between_evals)
+      **validation_kwargs)

-  validation_output = model.evaluate(
-      validation_dataset, steps=validation_steps, verbose=2)
+  validation_output = None
+  if not params.evaluation.skip_eval:
+    validation_output = model.evaluate(
+        validation_dataset, steps=validation_steps, verbose=2)

   # TODO(dankondratyuk): eval and save final test accuracy
   stats = common.build_stats(history,
                              validation_output,
                              callbacks)
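The skip_eval plumbing works because Keras Model.fit treats absent validation keyword arguments as "no validation", and unpacking an empty dict is equivalent to omitting them. A minimal illustration (names hypothetical):

# With skip_eval=True the dict is empty, so fit() runs training only.
validation_kwargs = {}
history = model.fit(train_dataset, epochs=1, **validation_kwargs)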
@@ -372,7 +407,7 @@ def export(params: base_configs.ExperimentConfig):
   """Runs the model export functionality."""
   logging.info('Exporting model.')
   model_params = params.model.model_params.as_dict()
-  model = MODELS[params.model.name](**model_params)
+  model = get_models()[params.model.name](**model_params)
   checkpoint = params.export.checkpoint
   if checkpoint is None:
     logging.info('No export checkpoint was provided. Using the latest '
@@ -395,8 +430,6 @@ def run(flags_obj: flags.FlagValues,
     Dictionary of training/eval stats
   """
   params = _get_params_from_flags(flags_obj)
-  initialize(params)
-
   if params.mode == 'train_and_eval':
     return train_and_eval(params, strategy_override)
   elif params.mode == 'export_only':
...
@@ -233,8 +233,8 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
   )
   def test_get_loss_scale(self, loss_scale, dtype, expected):
     config = base_configs.ExperimentConfig(
-        model=base_configs.ModelConfig(
-            loss=base_configs.LossConfig(loss_scale=loss_scale)),
+        runtime=base_configs.RuntimeConfig(
+            loss_scale=loss_scale),
         train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
     ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
     self.assertEqual(ls, expected)
@@ -246,7 +246,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
   def test_initialize(self, dtype):
     config = base_configs.ExperimentConfig(
         runtime=base_configs.RuntimeConfig(
-            enable_eager=False,
+            run_eagerly=False,
             enable_xla=False,
             gpu_threads_enabled=True,
             per_gpu_thread_count=1,
@@ -258,7 +258,14 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
         model=base_configs.ModelConfig(
             loss=base_configs.LossConfig(loss_scale='dynamic')),
     )
-    classifier_trainer.initialize(config)
+
+    class EmptyClass:
+      pass
+    fake_ds_builder = EmptyClass()
+    fake_ds_builder.dtype = dtype
+    fake_ds_builder.config = EmptyClass()
+    fake_ds_builder.config.data_format = None
+    classifier_trainer.initialize(config, fake_ds_builder)

   def test_resume_from_checkpoint(self):
     """Tests functionality for resuming from checkpoint."""
...
@@ -58,6 +58,17 @@ class MetricsConfig(base_config.Config):
   top_5: bool = None


+@dataclasses.dataclass
+class TimeHistoryConfig(base_config.Config):
+  """Configuration for the TimeHistory callback.
+
+  Attributes:
+    log_steps: Interval of steps between logging of batch level stats.
+  """
+  log_steps: int = None
+
+
 @dataclasses.dataclass
 class TrainConfig(base_config.Config):
   """Configuration for training.
@@ -77,8 +88,9 @@ class TrainConfig(base_config.Config):
   epochs: int = None
   steps: int = None
   callbacks: CallbacksConfig = CallbacksConfig()
-  metrics: List[str] = None
+  metrics: MetricsConfig = None
   tensorboard: TensorboardConfig = TensorboardConfig()
+  time_history: TimeHistoryConfig = TimeHistoryConfig()


 @dataclasses.dataclass
@@ -91,10 +103,12 @@ class EvalConfig(base_config.Config):
     steps: The number of eval steps to run during evaluation. If None, this
       will be inferred based on the number of images and batch size. Defaults
       to None.
+    skip_eval: Whether or not to skip evaluation.
   """
   epochs_between_evals: int = None
   steps: int = None
+  skip_eval: bool = False


 @dataclasses.dataclass
...
@@ -52,6 +52,7 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
       callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
                                              enable_tensorboard=True),
       metrics=['accuracy', 'top_5'],
+      time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
                                                  write_model_weights=False))
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
@@ -83,6 +84,7 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
                                              enable_tensorboard=True),
       metrics=['accuracy', 'top_5'],
+      time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
                                                  write_model_weights=False))
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
...
@@ -203,6 +203,30 @@ class DatasetBuilder:
     # Always divide by the global batch size to get the correct # of steps
     return self.num_examples // self.global_batch_size

+  @property
+  def dtype(self) -> tf.dtypes.DType:
+    """Converts the config's dtype string to a tf dtype.
+
+    Returns:
+      A mapping from string representation of a dtype to the `tf.dtypes.DType`.
+
+    Raises:
+      ValueError if the config's dtype is not supported.
+    """
+    dtype_map = {
+        'float32': tf.float32,
+        'bfloat16': tf.bfloat16,
+        'float16': tf.float16,
+        'fp32': tf.float32,
+        'bf16': tf.bfloat16,
+    }
+    try:
+      return dtype_map[self.config.dtype]
+    except:
+      raise ValueError('Invalid DType provided. Supported types: {}'.format(
+          dtype_map.keys()))
+
   @property
   def image_size(self) -> int:
     """The size of each image (can be inferred from the dataset)."""
@@ -326,7 +350,7 @@ class DatasetBuilder:
     def generate_data(_):
       image = tf.zeros([self.image_size, self.image_size, self.num_channels],
-                       dtype=self.config.dtype)
+                       dtype=self.dtype)
       label = tf.zeros([1], dtype=tf.int32)
       return image, label
@@ -451,7 +475,7 @@ class DatasetBuilder:
           image_size=self.image_size,
           mean_subtract=self.config.mean_subtract,
           standardize=self.config.standardize,
-          dtype=self.config.dtype,
+          dtype=self.dtype,
           augmenter=self.augmenter)
     else:
       image = preprocessing.preprocess_for_eval(
@@ -460,7 +484,7 @@ class DatasetBuilder:
           num_channels=self.num_channels,
           mean_subtract=self.config.mean_subtract,
           standardize=self.config.standardize,
-          dtype=self.config.dtype)
+          dtype=self.dtype)

     label = tf.cast(label, tf.int32)
     if self.config.one_hot:
...
@@ -166,7 +166,6 @@ def build_stats(history, eval_output, callbacks):
   if eval_output:
     stats['accuracy_top_1'] = float(eval_output[1])
     stats['eval_loss'] = float(eval_output[0])
-
   if history and history.history:
     train_hist = history.history
     # Gets final loss from training.
@@ -176,6 +175,8 @@ def build_stats(history, eval_output, callbacks):
       stats[TRAIN_TOP_1] = float(train_hist['categorical_accuracy'][-1])
     elif 'sparse_categorical_accuracy' in train_hist:
       stats[TRAIN_TOP_1] = float(train_hist['sparse_categorical_accuracy'][-1])
+    elif 'accuracy' in train_hist:
+      stats[TRAIN_TOP_1] = float(train_hist['accuracy'][-1])

     if not callbacks:
       return stats
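The new elif extends the key precedence for extracting top-1 training accuracy: 'categorical_accuracy', then 'sparse_categorical_accuracy', then plain 'accuracy' (the key Keras records when compiled with metrics=['accuracy']). A tiny sketch with illustrative values:

# With metrics=['accuracy'], Keras stores the metric under 'accuracy'.
train_hist = {'loss': [2.3, 1.1], 'accuracy': [0.41, 0.63]}
# build_stats now picks float(train_hist['accuracy'][-1]) == 0.63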
...