"vscode:/vscode.git/clone" did not exist on "169d5169fe4f805f39eef4a5b0aa2fe480190afe"
Commit b03da6ec authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Add in instructions for TFDS for classifier_trainer.

PiperOrigin-RevId: 303828114
parent fddab2eb
......@@ -19,11 +19,25 @@ installed and
### ImageNet preparation
#### Using TFDS
`classifier_trainer.py` supports ImageNet with
[TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py)
for more information on how to use TFDS to download and prepare datasets, and
specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md)
for manual download instructions.
#### Legacy TFRecords
Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
Note that the legacy ResNet runners, e.g. [resnet/resnet_ctl_imagenet_main.py](resnet/resnet_ctl_imagenet_main.py)
require TFRecords whereas `classifier_trainer.py` can use both by setting the
builder to 'records' or 'tfds' in the configurations.
### Running on Cloud TPUs
Note: These models will **not** work with TPUs on Colab.
......@@ -114,7 +128,7 @@ python3 classifier_trainer.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=config/examples/resnet/imagenet/tpu.yaml
--config_file=configs/examples/resnet/imagenet/tpu.yaml
```
### EfficientNet
......@@ -141,7 +155,7 @@ python3 classifier_trainer.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=config/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
```
Note that the number of GPU devices can be overridden in the command line using
......
......@@ -99,7 +99,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
image_size = get_image_size_from_model(params)
dataset_configs = [
params.train_dataset, params.validation_dataset, params.test_dataset
params.train_dataset, params.validation_dataset
]
builders = []
......@@ -154,9 +154,6 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
'test_dataset': {
'data_dir': flags_obj.data_dir,
},
}
overriding_configs = (flags_obj.config_file,
......@@ -300,8 +297,8 @@ def train_and_eval(
datasets = [builder.build() if builder else None for builder in builders]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder, test_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset, test_dataset = datasets
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset = datasets
train_epochs = params.train.epochs
train_steps = params.train.steps or train_builder.num_steps
......
......@@ -82,12 +82,6 @@ def basic_params_override() -> MutableMapping[str, Any]:
'use_per_replica_batch_size': True,
'image_size': 224,
},
'test_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
},
'train': {
'steps': 1,
'epochs': 1,
......
......@@ -216,7 +216,6 @@ class ExperimentConfig(base_config.Config):
runtime: RuntimeConfig = None
train_dataset: Any = None
validation_dataset: Any = None
test_dataset: Any = None
train: TrainConfig = None
evaluation: EvalConfig = None
model: ModelConfig = None
......
......@@ -45,8 +45,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
dataset_factory.ImageNetConfig(split='train')
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
test_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=500,
......@@ -78,11 +76,6 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
one_hot=False,
mean_subtract=True,
standardize=True)
test_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
......
# Training configuration for ResNet trained on ImageNet on GPUs.
# Takes ~3 minutes, 15 seconds per epoch for 8 V100s.
# Reaches ~76.1% within 90 epochs.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
......@@ -10,7 +9,7 @@ runtime:
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'train'
image_size: 224
num_classes: 1000
......@@ -23,7 +22,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'validation'
image_size: 224
num_classes: 1000
......
# Training configuration for ResNet trained on ImageNet on TPUs.
# Takes ~2 minutes, 43 seconds per epoch for a v3-32.
# Reaches ~76.1% within 90 epochs.
# Takes ~4 minutes, 30 seconds seconds per epoch for a v3-32.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
model_dir: null
......@@ -9,7 +9,7 @@ runtime:
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'train'
one_hot: False
image_size: 224
......@@ -23,7 +23,7 @@ train_dataset:
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
builder: 'tfds'
split: 'validation'
one_hot: False
image_size: 224
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment