Commit 8edebf5d authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

TFDS support for classification datasets.

Imagenet, and CIFAR.

PiperOrigin-RevId: 355665447
parent b9c9a142
# ResNet-50 ImageNet classification. 78.1% top-1 and 93.9% top-5 accuracy.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
num_classes: 1001
input_size: [224, 224, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
norm_activation:
activation: 'swish'
losses:
l2_weight_decay: 0.0001
one_hot: true
label_smoothing: 0.1
train_data:
input_path: ''
tfds_name: 'imagenet2012'
tfds_split: 'train'
sharding: true
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
validation_data:
input_path: ''
tfds_name: 'imagenet2012'
tfds_split: 'validation'
sharding: true
is_training: false
global_batch_size: 4096
dtype: 'bfloat16'
drop_remainder: false
trainer:
train_steps: 62400
validation_steps: 13
validation_interval: 312
steps_per_loop: 312
summary_interval: 312
checkpoint_interval: 312
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'cosine'
cosine:
initial_learning_rate: 1.6
decay_steps: 62400
warmup:
type: 'linear'
linear:
warmup_steps: 1560
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFDS Classification decoders."""
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
class ClassificationDecorder(decoder.Decoder):
"""A tf.Example decoder for tfds classification datasets."""
def decode(self, serialized_example):
sample_dict = {
'image/encoded':
tf.io.encode_jpeg(serialized_example['image'], quality=100),
'image/class/label':
serialized_example['label'],
}
return sample_dict
TFDS_ID_TO_DECODER_MAP = {
'cifar10': ClassificationDecorder,
'cifar100': ClassificationDecorder,
'imagenet2012': ClassificationDecorder,
}
...@@ -23,6 +23,7 @@ from official.core import task_factory ...@@ -23,6 +23,7 @@ from official.core import task_factory
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.configs import image_classification as exp_cfg from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.dataloaders import classification_input from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import tfds_classification_decoders
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -78,7 +79,15 @@ class ImageClassificationTask(base_task.Task): ...@@ -78,7 +79,15 @@ class ImageClassificationTask(base_task.Task):
num_classes = self.task_config.model.num_classes num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size input_size = self.task_config.model.input_size
decoder = classification_input.Decoder() if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder = classification_input.Decoder()
parser = classification_input.Parser( parser = classification_input.Parser(
output_size=input_size[:2], output_size=input_size[:2],
num_classes=num_classes, num_classes=num_classes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment