Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Image Classification
**Warning:** the features in the `image_classification/` folder have been fully
integrated into vision/beta. Please use the [new code base](../../vision/beta/README.md).
**Warning:** the features in the `image_classification/` directory have been
fully integrated into the [new code base](https://github.com/tensorflow/models/tree/benchmark/official/vision/modeling/backbones).
This folder contains TF 2.0 model examples for image classification:
This folder contains TF 2 model examples for image classification:
* [MNIST](#mnist)
* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras
......@@ -17,8 +17,7 @@ For more information about other types of models, please refer to this
## Before you begin
Please make sure that you have the latest version of TensorFlow
installed and
[add the models folder to your Python path](/official/#running-the-models).
installed and add the models folder to your Python path.
### ImageNet preparation
......@@ -70,6 +69,7 @@ available GPUs at each host.
To download the data and run the MNIST sample model locally for the first time,
run one of the following command:
<details>
```bash
python3 mnist_main.py \
--model_dir=$MODEL_DIR \
......@@ -79,9 +79,11 @@ python3 mnist_main.py \
--num_gpus=$NUM_GPUS \
--download
```
</details>
To train the model on a Cloud TPU, run the following command:
<details>
```bash
python3 mnist_main.py \
--tpu=$TPU_NAME \
......@@ -91,10 +93,10 @@ python3 mnist_main.py \
--distribution_strategy=tpu \
--download
```
</details>
Note: the `--download` flag is only required the first time you run the model.
## Classifier Trainer
The classifier trainer is a unified framework for running image classification
models using Keras's compile/fit methods. Experiments should be provided in the
......@@ -111,6 +113,8 @@ be 64 * 8 = 512, and for a v3-32, the global batch size is 64 * 32 = 2048.
### ResNet50
#### On GPU:
<details>
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
......@@ -121,12 +125,15 @@ python3 classifier_trainer.py \
--config_file=configs/examples/resnet/imagenet/gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
</details>
To train on multiple hosts, each with GPUs attached using
[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
please update `runtime` section in gpu.yaml
(or override using `--params_override`) with:
<details>
```YAML
# gpu.yaml
runtime:
......@@ -135,12 +142,16 @@ runtime:
num_gpus: $NUM_GPUS
task_index: 0
```
</details>
By having `task_index: 0` on the first host and `task_index: 1` on the second
and so on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port`
can be chosen any free port on the hosts. Only the first host will write
TensorBoard Summaries and save checkpoints.
#### On TPU:
<details>
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
......@@ -152,9 +163,31 @@ python3 classifier_trainer.py \
--config_file=configs/examples/resnet/imagenet/tpu.yaml
```
</details>
### VGG-16
#### On GPU:
<details>
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=vgg \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/vgg/imagenet/gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
</details>
### EfficientNet
**Note: EfficientNet development is a work in progress.**
#### On GPU:
<details>
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
......@@ -166,8 +199,11 @@ python3 classifier_trainer.py \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
</details>
#### On TPU:
<details>
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
......@@ -178,6 +214,7 @@ python3 classifier_trainer.py \
--data_dir=$DATA_DIR \
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
```
</details>
Note that the number of GPU devices can be overridden in the command line using
`--params_overrides`. The TPU does not need this override as the device is fixed
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Common modules for callbacks."""
from __future__ import absolute_import
from __future__ import division
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Runs an Image Classification model."""
import os
......@@ -32,6 +31,7 @@ from official.legacy.image_classification.configs import configs
from official.legacy.image_classification.efficientnet import efficientnet_model
from official.legacy.image_classification.resnet import common
from official.legacy.image_classification.resnet import resnet_model
from official.legacy.image_classification.vgg import vgg_model
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
......@@ -43,6 +43,7 @@ def get_models() -> Mapping[str, tf.keras.Model]:
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
'vgg': vgg_model.vgg16,
}
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,13 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
......@@ -53,6 +48,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset=[
'imagenet',
......@@ -149,6 +145,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='float16',
......@@ -193,6 +190,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model=[
'efficientnet',
'resnet',
'vgg',
],
dataset='imagenet',
dtype='bfloat16',
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for high level configuration groups.."""
import dataclasses
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration utils for image classification experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dataclasses
......@@ -24,6 +20,7 @@ from official.legacy.image_classification import dataset_factory
from official.legacy.image_classification.configs import base_configs
from official.legacy.image_classification.efficientnet import efficientnet_config
from official.legacy.image_classification.resnet import resnet_config
from official.legacy.image_classification.vgg import vgg_config
@dataclasses.dataclass
......@@ -92,12 +89,38 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
@dataclasses.dataclass
class VGGImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train vgg-16 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='train', one_hot=False, mean_subtract=True, standardize=True)
validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
split='validation', one_hot=False, mean_subtract=True, standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
dataset_model_config_map = {
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
'vgg': VGGImagenetConfig(),
}
}
try:
......
# Training configuration for VGG-16 trained on ImageNet on GPUs.
# Reaches > 72.8% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'mirrored'
num_gpus: 1
batchnorm_spatial_persistent: true
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
image_size: 224
num_classes: 1000
num_examples: 1281167
batch_size: 128
use_per_replica_batch_size: true
dtype: 'float32'
mean_subtract: true
standardize: true
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
image_size: 224
num_classes: 1000
num_examples: 50000
batch_size: 128
use_per_replica_batch_size: true
dtype: 'float32'
mean_subtract: true
standardize: true
model:
name: 'vgg'
optimizer:
name: 'momentum'
momentum: 0.9
epsilon: 0.001
loss:
label_smoothing: 0.0
train:
resume_checkpoint: true
epochs: 90
evaluation:
epochs_between_evals: 1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
from __future__ import absolute_import
from __future__ import division
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration definitions for EfficientNet losses, learning rates, and optimizers."""
from __future__ import absolute_import
from __future__ import division
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment