Commit 2d16421a authored by A. Unique TensorFlower

Merge pull request #10449 from miguelCalado:vgg

PiperOrigin-RevId: 421667465
parents c9a7e0b2 3e7fe8a1
@@ -152,6 +152,20 @@ python3 classifier_trainer.py \
--config_file=configs/examples/resnet/imagenet/tpu.yaml
```
### VGG-16
#### On GPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=vgg \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/vgg/imagenet/gpu.yaml \
--params_override="runtime.num_gpus=$NUM_GPUS"
```
### EfficientNet
**Note: EfficientNet development is a work in progress.**
#### On GPU:
...
@@ -32,6 +32,7 @@ from official.legacy.image_classification.configs import configs
from official.legacy.image_classification.efficientnet import efficientnet_model
from official.legacy.image_classification.resnet import common
from official.legacy.image_classification.resnet import resnet_model
from official.legacy.image_classification.vgg import vgg_model
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
@@ -43,6 +44,7 @@ def get_models() -> Mapping[str, tf.keras.Model]:
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
'vgg': vgg_model.vgg16,
}
...
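For reference, a minimal usage sketch of the new registry entry (assumed call pattern; the trainer looks up `--model_type` in `get_models()` and invokes the returned constructor with the config's `model_params`):

```python
# Hedged sketch, not part of the diff: resolve 'vgg' through the model registry
# and build the network roughly the way the trainer would.
from official.legacy.image_classification import classifier_trainer

vgg_builder = classifier_trainer.get_models()['vgg']   # -> vgg_model.vgg16
model = vgg_builder(num_classes=1000, batch_size=None, use_l2_regularizer=True)
```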
@@ -15,10 +15,6 @@
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
@@ -53,6 +49,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset=[
'imagenet',
@@ -149,6 +146,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='float16',
@@ -193,6 +191,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='bfloat16',
...
@@ -14,9 +14,6 @@
# Lint as: python3
"""Configuration utils for image classification experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dataclasses
@@ -24,6 +21,7 @@ from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.efficientnet import efficientnet_config
from official.legacy.image_classification.resnet import resnet_config
from official.legacy.image_classification.vgg import vgg_config
@dataclasses.dataclass
@@ -92,12 +90,38 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
@dataclasses.dataclass
class VGGImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train vgg-16 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='train', one_hot=False, mean_subtract=True, standardize=True)
validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='validation', one_hot=False, mean_subtract=True, standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
dataset_model_config_map = {
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
'vgg': VGGImagenetConfig(),
}
}
try:
...
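With the mapping above, `get_config('vgg', 'imagenet')` returns the `VGGImagenetConfig` defined earlier; a quick, hedged illustration:

```python
# Hedged usage sketch, not part of the diff.
from official.legacy.image_classification.configs import configs

config = configs.get_config(model='vgg', dataset='imagenet')
print(type(config).__name__)                   # VGGImagenetConfig
print(config.train.epochs)                     # 90
print(config.evaluation.epochs_between_evals)  # 1
```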
# Training configuration for VGG-16 trained on ImageNet on GPUs.
# Reaches > 72.8% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'mirrored'
num_gpus: 1
batchnorm_spatial_persistent: true
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
image_size: 224
num_classes: 1000
num_examples: 1281167
batch_size: 128
use_per_replica_batch_size: true
dtype: 'float32'
mean_subtract: true
standardize: true
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
image_size: 224
num_classes: 1000
num_examples: 50000
batch_size: 128
use_per_replica_batch_size: true
dtype: 'float32'
mean_subtract: true
standardize: true
model:
name: 'vgg'
optimizer:
name: 'momentum'
momentum: 0.9
epsilon: 0.001
loss:
label_smoothing: 0.0
train:
resume_checkpoint: true
epochs: 90
evaluation:
epochs_between_evals: 1
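Because the YAML sets `use_per_replica_batch_size: true`, the effective global batch grows with the replica count; a back-of-the-envelope sketch (the 4-GPU figure is only an example):

```python
# per_replica_batch_size comes from the YAML above; num_gpus is hypothetical
# (e.g. set via --params_override='runtime.num_gpus=4').
per_replica_batch_size = 128
num_gpus = 4
global_batch_size = per_replica_batch_size * num_gpus
print(global_batch_size)  # 512 images per optimizer step
```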
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration definitions for VGG losses, learning rates, and optimizers."""
import dataclasses
from official.legacy.image_classification.configs import base_configs
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class VGGModelConfig(base_configs.ModelConfig):
"""Configuration for the VGG model."""
name: str = 'VGG'
num_classes: int = 1000
model_params: base_config.Config = dataclasses.field(default_factory=lambda: { # pylint:disable=g-long-lambda
'num_classes': 1000,
'batch_size': None,
'use_l2_regularizer': True
})
loss: base_configs.LossConfig = base_configs.LossConfig(
name='sparse_categorical_crossentropy')
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='momentum', epsilon=0.001, momentum=0.9, moving_average_decay=None)
learning_rate: base_configs.LearningRateConfig = (
base_configs.LearningRateConfig(
name='stepwise',
initial_lr=0.01,
examples_per_epoch=1281167,
boundaries=[30, 60],
warmup_epochs=0,
scale_by_batch_size=1. / 256.,
multipliers=[0.01 / 256, 0.001 / 256, 0.0001 / 256]))
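One reading of the stepwise schedule above: the `multipliers` are per-image rates that get rescaled by the global batch size (`scale_by_batch_size=1/256`), so the learning rate drops 10x at epochs 30 and 60. A hedged arithmetic check for the single-GPU YAML (global batch 128):

```python
# Assumption: effective LR = multiplier * global_batch_size, matching the
# 1/256 base-batch convention used elsewhere in this trainer.
global_batch_size = 128
multipliers = [0.01 / 256, 0.001 / 256, 0.0001 / 256]
effective_lrs = [m * global_batch_size for m in multipliers]
print(effective_lrs)  # [0.005, 0.0005, 5e-05], stepping down at epochs 30 and 60
```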
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VGG16 model for Keras.
Adapted from tf.keras.applications.vgg16.VGG16().
Related papers/blogs:
- https://arxiv.org/abs/1409.1556
"""
import tensorflow as tf
layers = tf.keras.layers
def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
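# Returns an L2 kernel regularizer (default decay 1e-4), or None when disabled.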
return tf.keras.regularizers.L2(
l2_weight_decay) if use_l2_regularizer else None
def vgg16(num_classes,
batch_size=None,
use_l2_regularizer=True,
batch_norm_decay=0.9,
batch_norm_epsilon=1e-5):
"""Instantiates the VGG16 architecture.
Args:
num_classes: `int` number of classes for image classification.
batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
A Keras model instance.
"""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape, batch_size=batch_size)
x = img_input
if tf.keras.backend.image_data_format() == 'channels_first':
x = layers.Permute((3, 1, 2))(x)
bn_axis = 1
else: # channels_last
bn_axis = 3
# Block 1
x = layers.Conv2D(
64, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block1_conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv1')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
64, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block1_conv2')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv2')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = layers.Conv2D(
128, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block2_conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv3')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
128, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block2_conv2')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv4')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = layers.Conv2D(
256, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block3_conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv5')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
256, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block3_conv2')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv6')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
256, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block3_conv3')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv7')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block4_conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv8')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block4_conv2')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv9')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block4_conv3')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv10')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block5_conv1')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv11')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block5_conv2')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv12')(
x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(
512, (3, 3),
padding='same',
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='block5_conv3')(
x)
x = layers.BatchNormalization(
axis=bn_axis,
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
name='bn_conv13')(
x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
x = layers.Flatten(name='flatten')(x)
x = layers.Dense(
4096,
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1')(
x)
x = layers.Activation('relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(
4096,
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc2')(
x)
x = layers.Activation('relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(
num_classes,
kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
name='fc1000')(
x)
x = layers.Activation('softmax', dtype='float32')(x)
# Create model.
return tf.keras.Model(img_input, x, name='vgg16')
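A small, hedged smoke test for the factory above (assumes the `official` package is importable; not part of the commit):

```python
import tensorflow as tf

from official.legacy.image_classification.vgg import vgg_model

# Build the 1000-class ImageNet variant and push a dummy batch through it.
model = vgg_model.vgg16(num_classes=1000)
logits = model(tf.random.uniform([2, 224, 224, 3]), training=False)
print(model.name, logits.shape)  # vgg16 (2, 1000)
```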