Description: Use KerasCV to train an image classifier using modern best practices
"""
importmath
importsys
importtensorflowastf
fromabslimportflags
fromtensorflowimportkeras
fromtensorflow.kerasimportcallbacks
fromtensorflow.kerasimportlosses
fromtensorflow.kerasimportmetrics
fromtensorflow.kerasimportoptimizers
importkeras_cv
fromkeras_cvimportmodels
fromkeras_cv.datasetsimportimagenet
"""
## Overview
KerasCV makes training state-of-the-art classification models easy by providing implementations of modern models, preprocessing techniques, and layers.
In this tutorial, we walk through training a model against the Imagenet dataset using Keras and KerasCV.
This tutorial requires you to have KerasCV installed:
```shell
pip install keras-cv
```
"""
"""
## Setup, constants and flags
"""
flags.DEFINE_string(
"model_name",None,"The name of the model in KerasCV.models to use."
)
flags.DEFINE_string("imagenet_path",None,"Directory from which to load Imagenet.")
flags.DEFINE_string(
"backup_path",None,"Directory which will be used for training backups."
)
flags.DEFINE_string(
"weights_path",None,"Directory which will be used to store weight checkpoints."
)
flags.DEFINE_string(
"tensorboard_path",None,"Directory which will be used to store tensorboard logs."
)
flags.DEFINE_integer(
"batch_size",
128,
"Batch size for training and evaluation. This will be multiplied by the number of accelerators in use.",
)
flags.DEFINE_boolean(
"use_xla",True,"Whether or not to use XLA (jit_compile) for training."
)
flags.DEFINE_boolean(
"use_mixed_precision",
False,
"Whether or not to use FP16 mixed precision for training.",
)
flags.DEFINE_float(
"initial_learning_rate",
0.05,
"Initial learning rate which will reduce on plateau. This will be multiplied by the number of accelerators in use",
)
flags.DEFINE_string(
"model_kwargs",
"{}",
"Keyword argument dictionary to pass to the constructor of the model being trained",
)
flags.DEFINE_string(
"learning_rate_schedule",
"ReduceOnPlateau",
"String denoting the type of learning rate schedule to be used",
)
flags.DEFINE_float(
"warmup_steps_percentage",
0.1,
"For how many steps expressed in percentage (0..1 float) of total steps should the schedule warm up if we're using the warmup schedule",
)
flags.DEFINE_float(
"warmup_hold_steps_percentage",
0.1,
"For how many steps expressed in percentage (0..1 float) of total steps should the schedule hold the initial learning rate after warmup is finished, and before applying cosine decay.",
)
# An upper bound for number of epochs (this script uses EarlyStopping).
flags.DEFINE_integer("epochs",1000,"Epochs to train for")
FLAGS=flags.FLAGS
FLAGS(sys.argv)
CLASSES=1000
IMAGE_SIZE=(224,224)
REDUCE_ON_PLATEAU="ReduceOnPlateau"
COSINE_DECAY_WITH_WARMUP="CosineDecayWithWarmup"
ifFLAGS.model_namenotinmodels.__dict__:
raiseValueError(f"Invalid model name: {FLAGS.model_name}")