ModelZoo / ResNet50_tensorflow / Commits / 965cc3ee

Unverified commit 965cc3ee, authored Apr 21, 2020 by Ayushman Kumar, committed by GitHub on Apr 21, 2020

Merge pull request #7 from tensorflow/master

updated

Parents: 1f3247f4, 1f685c54

Showing 20 changed files with 601 additions and 213 deletions (+601 -213)
official/vision/image_classification/augment.py                                                        +11  -14
official/vision/image_classification/augment_test.py                                                   +11   -5
official/vision/image_classification/callbacks.py                                                     +145  -23
official/vision/image_classification/classifier_trainer.py                                             +83  -54
official/vision/image_classification/classifier_trainer_test.py                                        +86  -17
official/vision/image_classification/configs/base_configs.py                                           +21   -5
official/vision/image_classification/configs/configs.py                                                 +6   -9
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml    +3   -2
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml    +2   -3
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml    +5   -2
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml    +3   -2
official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml                          +4   -7
official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml                          +5   -7
official/vision/image_classification/dataset_factory.py                                                +90  -29
official/vision/image_classification/efficientnet/common_modules.py                                    +20   -3
official/vision/image_classification/efficientnet/efficientnet_config.py                               +17  -14
official/vision/image_classification/efficientnet/efficientnet_model.py                                +15  -14
official/vision/image_classification/efficientnet/tfhub_export.py                                      +72   -0
official/vision/image_classification/learning_rate.py                                                   +1   -1
official/vision/image_classification/learning_rate_test.py                                              +1   -2
official/vision/image_classification/augment.py  (+11 -14)

@@ -24,8 +24,8 @@ from __future__ import division
 from __future__ import print_function

 import math
-import tensorflow.compat.v2 as tf
-from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
+import tensorflow as tf
+from typing import Any, Dict, List, Optional, Text, Tuple

 from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops

@@ -66,7 +66,7 @@ def to_4d(image: tf.Tensor) -> tf.Tensor:
   return tf.reshape(image, new_shape)


-def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
+def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
   """Converts a 4D image back to `ndims` rank."""
   shape = tf.shape(image)
   begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)

@@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
   return tf.reshape(image, new_shape)


-def _convert_translation_to_transform(
-    translations: Iterable[int]) -> tf.Tensor:
+def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
   """Converts translations to a projective transform.

   The translation matrix looks like this:

@@ -122,9 +121,9 @@ def _convert_translation_to_transform(
 def _convert_angles_to_transform(
-    angles: Union[Iterable[float], float],
-    image_width: int,
-    image_height: int) -> tf.Tensor:
+    angles: tf.Tensor,
+    image_width: tf.Tensor,
+    image_height: tf.Tensor) -> tf.Tensor:
   """Converts an angle or angles to a projective transform.

   Args:

@@ -166,8 +165,7 @@ def _convert_angles_to_transform(
   )


-def transform(image: tf.Tensor,
-              transforms: Iterable[float]) -> tf.Tensor:
+def transform(image: tf.Tensor, transforms) -> tf.Tensor:
   """Prepares input data for `image_ops.transform`."""
   original_ndims = tf.rank(image)
   transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)

@@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
   return from_4d(image, original_ndims)


-def translate(image: tf.Tensor,
-              translations: Iterable[int]) -> tf.Tensor:
+def translate(image: tf.Tensor, translations) -> tf.Tensor:
   """Translates image(s) by provided vectors.

   Args:

@@ -212,7 +209,7 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
   """
   # Convert from degrees to radians.
   degrees_to_radians = math.pi / 180.0
-  radians = degrees * degrees_to_radians
+  radians = tf.cast(degrees * degrees_to_radians, tf.float32)
   original_ndims = tf.rank(image)
   image = to_4d(image)

@@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
   return image


-def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
+def _randomly_negate_tensor(tensor):
   """With 50% prob turn the tensor negative."""
   should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
   final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
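One change above worth calling out: rotate() now casts the angle to float32 before building the projective transform. A minimal standalone sketch (not part of the diff) of why the cast helps: the degree-to-radian product is a Python float (float64), and casting it up front presumably avoids dtype mismatches with the float32 transform tensors built downstream.

import math
import tensorflow as tf  # assumes TF 2.x

degrees = 30  # may arrive as a Python int or float
radians = tf.cast(degrees * (math.pi / 180.0), tf.float32)
print(radians.dtype)  # float32, regardless of the input type of `degrees`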
official/vision/image_classification/augment_test.py  (+11 -5)

@@ -21,7 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 from official.vision.image_classification import augment

@@ -52,14 +52,21 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
     self.assertAllEqual(augment.transform(image, transforms=[1]*8),
                         [[4, 4], [4, 4]])

-  def disable_test_translate(self, dtype):
+  def test_translate(self, dtype):
     image = tf.constant(
         [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
         dtype=dtype)
     translations = [-1, -1]
     translated = augment.translate(image=image,
                                    translations=translations)
-    expected = [[1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]
+    expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
     self.assertAllEqual(translated, expected)

   def test_translate_shapes(self, dtype):

@@ -133,5 +140,4 @@ class AutoaugmentTest(tf.test.TestCase):
     self.assertEqual((224, 224, 3), image.shape)


 if __name__ == '__main__':
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
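The previously disabled translate test is re-enabled here with updated expected pixels. A quick way to reproduce the expectation interactively (a sketch assuming TF 2.x and the tensorflow/models repo on PYTHONPATH):

import tensorflow as tf
from official.vision.image_classification import augment

image = tf.constant(
    [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=tf.uint8)
# Should print the new `expected` array from the test above.
print(augment.translate(image=image, translations=[-1, -1]).numpy())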
official/vision/image_classification/callbacks.py  (+145 -23)

+# Lint as: python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -19,18 +20,24 @@ from __future__ import division
 from __future__ import print_function

 import os
+from typing import Any, List, MutableMapping, Text

 from absl import logging
 import tensorflow as tf
-from typing import Any, List, MutableMapping, Text

 from official.utils.misc import keras_utils
+from official.vision.image_classification import optimizer_factory


 def get_callbacks(model_checkpoint: bool = True,
                   include_tensorboard: bool = True,
+                  time_history: bool = True,
                   track_lr: bool = True,
                   write_model_weights: bool = True,
+                  apply_moving_average: bool = False,
                   initial_step: int = 0,
-                  model_dir: Text = None) -> List[tf.keras.callbacks.Callback]:
+                  batch_size: int = 0,
+                  log_steps: int = 0,
+                  model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
   """Get all callbacks."""
   model_dir = model_dir or ''
   callbacks = []

@@ -39,11 +46,29 @@ def get_callbacks(model_checkpoint: bool = True,
     callbacks.append(tf.keras.callbacks.ModelCheckpoint(
         ckpt_full_path, save_weights_only=True, verbose=1))
   if include_tensorboard:
-    callbacks.append(CustomTensorBoard(
-        log_dir=model_dir,
-        track_lr=track_lr,
-        initial_step=initial_step,
-        write_images=write_model_weights))
+    callbacks.append(
+        CustomTensorBoard(
+            log_dir=model_dir,
+            track_lr=track_lr,
+            initial_step=initial_step,
+            write_images=write_model_weights))
+  if time_history:
+    callbacks.append(keras_utils.TimeHistory(
+        batch_size,
+        log_steps,
+        logdir=model_dir if include_tensorboard else None))
+  if apply_moving_average:
+    # Save moving average model to a different file so that
+    # we can resume training from a checkpoint
+    ckpt_full_path = os.path.join(model_dir, 'average',
+                                  'model.ckpt-{epoch:04d}')
+    callbacks.append(AverageModelCheckpoint(
+        update_weights=False,
+        filepath=ckpt_full_path,
+        save_weights_only=True,
+        verbose=1))
+    callbacks.append(MovingAverageCallback())
   return callbacks

@@ -63,18 +88,19 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   - Global learning rate

   Attributes:
     log_dir: the path of the directory where to save the log files to be
       parsed by TensorBoard.
     track_lr: `bool`, whether or not to track the global learning rate.
     initial_step: the initial step, used for preemption recovery.
     **kwargs: Additional arguments for backwards compatibility. Possible key
       is `period`.
   """

   # TODO(b/146499062): track params, flops, log lr, l2 loss,
   # classification loss

   def __init__(self,
-               log_dir: Text,
+               log_dir: str,
                track_lr: bool = False,
                initial_step: int = 0,
                **kwargs):

@@ -84,7 +110,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_batch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     self.step += 1
     if logs is None:
       logs = {}

@@ -93,7 +119,7 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_epoch_begin(self,
                      epoch: int,
-                     logs: MutableMapping[Text, Any] = None) -> None:
+                     logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()

@@ -104,25 +130,24 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   def on_epoch_end(self,
                    epoch: int,
-                   logs: MutableMapping[Text, Any] = None) -> None:
+                   logs: MutableMapping[str, Any] = None) -> None:
     if logs is None:
       logs = {}
     metrics = self._calculate_metrics()
     logs.update(metrics)
     super(CustomTensorBoard, self).on_epoch_end(epoch, logs)

-  def _calculate_metrics(self) -> MutableMapping[Text, Any]:
+  def _calculate_metrics(self) -> MutableMapping[str, Any]:
     logs = {}
-    if self._track_lr:
-      logs['learning_rate'] = self._calculate_lr()
+    # TODO(b/149030439): disable LR reporting.
+    # if self._track_lr:
+    #   logs['learning_rate'] = self._calculate_lr()
     return logs

   def _calculate_lr(self) -> int:
     """Calculates the learning rate given the current step."""
-    lr = self._get_base_optimizer().lr
-    if callable(lr):
-      lr = lr(self.step)
-    return get_scalar_from_tensor(lr)
+    return get_scalar_from_tensor(
+        self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))  # pylint:disable=protected-access

   def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
     """Get the base optimizer used by the current model."""

@@ -134,3 +159,100 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
     optimizer = optimizer._optimizer  # pylint:disable=protected-access
     return optimizer
+
+
+class MovingAverageCallback(tf.keras.callbacks.Callback):
+  """A Callback to be used with a `MovingAverage` optimizer.
+
+  Applies moving average weights to the model during validation time to test
+  and predict on the averaged weights rather than the current model weights.
+  Once training is complete, the model weights will be overwritten with the
+  averaged weights (by default).
+
+  Attributes:
+    overwrite_weights_on_train_end: Whether to overwrite the current model
+      weights with the averaged weights from the moving average optimizer.
+    **kwargs: Any additional callback arguments.
+  """
+
+  def __init__(self,
+               overwrite_weights_on_train_end: bool = False,
+               **kwargs):
+    super(MovingAverageCallback, self).__init__(**kwargs)
+    self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
+
+  def set_model(self, model: tf.keras.Model):
+    super(MovingAverageCallback, self).set_model(model)
+    assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
+    self.model.optimizer.shadow_copy(self.model)
+
+  def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
+    self.model.optimizer.swap_weights()
+
+  def on_test_end(self, logs: MutableMapping[Text, Any] = None):
+    self.model.optimizer.swap_weights()
+
+  def on_train_end(self, logs: MutableMapping[Text, Any] = None):
+    if self.overwrite_weights_on_train_end:
+      self.model.optimizer.assign_average_vars(self.model.variables)
+
+
+class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
+  """Saves and, optionally, assigns the averaged weights.
+
+  Taken from tfa.callbacks.AverageModelCheckpoint.
+
+  Attributes:
+    update_weights: If True, assign the moving average weights
+      to the model, and save them. If False, keep the old
+      non-averaged weights, but the saved model uses the
+      average weights.
+    See `tf.keras.callbacks.ModelCheckpoint` for the other args.
+  """
+
+  def __init__(self,
+               update_weights: bool,
+               filepath: str,
+               monitor: str = 'val_loss',
+               verbose: int = 0,
+               save_best_only: bool = False,
+               save_weights_only: bool = False,
+               mode: str = 'auto',
+               save_freq: str = 'epoch',
+               **kwargs):
+    self.update_weights = update_weights
+    super().__init__(filepath, monitor, verbose, save_best_only,
+                     save_weights_only, mode, save_freq, **kwargs)
+
+  def set_model(self, model):
+    if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
+      raise TypeError('AverageModelCheckpoint is only used when training'
+                      'with MovingAverage')
+    return super().set_model(model)
+
+  def _save_model(self, epoch, logs):
+    assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
+
+    if self.update_weights:
+      self.model.optimizer.assign_average_vars(self.model.variables)
+      return super()._save_model(epoch, logs)
+    else:
+      # Note: `model.get_weights()` gives us the weights (non-ref)
+      # whereas `model.variables` returns references to the variables.
+      non_avg_weights = self.model.get_weights()
+      self.model.optimizer.assign_average_vars(self.model.variables)
+      # result is currently None, since `super._save_model` doesn't
+      # return anything, but this may change in the future.
+      result = super()._save_model(epoch, logs)
+      self.model.set_weights(non_avg_weights)
+      return result
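The moving-average path threads through the extended get_callbacks signature above. A minimal wiring sketch (values are illustrative; the compiled optimizer must be the MovingAverage wrapper from the repo's optimizer_factory for AverageModelCheckpoint's isinstance check to pass):

from official.vision.image_classification import callbacks as custom_callbacks

cbs = custom_callbacks.get_callbacks(
    model_checkpoint=True,
    include_tensorboard=True,
    time_history=True,
    track_lr=True,
    write_model_weights=False,
    apply_moving_average=True,   # appends AverageModelCheckpoint + MovingAverageCallback
    initial_step=0,
    batch_size=256,              # feeds keras_utils.TimeHistory
    log_steps=100,               # ditto
    model_dir='/tmp/model_dir')  # hypothetical path
# model.fit(..., callbacks=cbs) then checkpoints EMA weights under
# <model_dir>/average/ while the regular checkpoints keep the raw weights.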
official/vision/image_classification/classifier_trainer.py  (+83 -54)

@@ -27,12 +27,11 @@ from typing import Any, Tuple, Text, Optional, Mapping
 from absl import app
 from absl import flags
 from absl import logging
-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 from official.modeling import performance
 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
-from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.vision.image_classification import callbacks as custom_callbacks

@@ -44,10 +43,24 @@ from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_model

-MODELS = {
-    'efficientnet': efficientnet_model.EfficientNet.from_name,
-    'resnet': resnet_model.resnet50,
-}
+
+def get_models() -> Mapping[str, tf.keras.Model]:
+  """Returns the mapping from model type name to Keras model."""
+  return {
+      'efficientnet': efficientnet_model.EfficientNet.from_name,
+      'resnet': resnet_model.resnet50,
+  }
+
+
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+  """Returns the mapping from dtype string representations to TF dtypes."""
+  return {
+      'float32': tf.float32,
+      'bfloat16': tf.bfloat16,
+      'float16': tf.float16,
+      'fp32': tf.float32,
+      'bf16': tf.bfloat16,
+  }


 def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:

@@ -87,19 +100,20 @@ def get_image_size_from_model(
 def _get_dataset_builders(params: base_configs.ExperimentConfig,
                           strategy: tf.distribute.Strategy,
-                          one_hot: bool) -> Tuple[Any, Any, Any]:
-  """Create and return train, validation, and test dataset builders."""
+                          one_hot: bool) -> Tuple[Any, Any]:
+  """Create and return train and validation dataset builders."""
   if one_hot:
     logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
   else:
     logging.warning('label_smoothing not applied, so datasets will not be one '
                     'hot encoded.')

-  num_devices = strategy.num_replicas_in_sync
+  num_devices = strategy.num_replicas_in_sync if strategy else 1
   image_size = get_image_size_from_model(params)

   dataset_configs = [
-      params.train_dataset, params.validation_dataset, params.test_dataset
+      params.train_dataset, params.validation_dataset
   ]
   builders = []

@@ -120,12 +134,13 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
 def get_loss_scale(params: base_configs.ExperimentConfig,
                    fp16_default: float = 128.) -> float:
   """Returns the loss scale for initializations."""
-  loss_scale = params.model.loss.loss_scale
+  loss_scale = params.runtime.loss_scale
   if loss_scale == 'dynamic':
     return loss_scale
   elif loss_scale is not None:
     return float(loss_scale)
-  elif params.train_dataset.dtype == 'float32':
+  elif (params.train_dataset.dtype == 'float32' or
+        params.train_dataset.dtype == 'bfloat16'):
     return 1.
   else:
     assert params.train_dataset.dtype == 'float16'

@@ -145,7 +160,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
           'name': model,
       },
      'runtime': {
-          'enable_eager': flags_obj.enable_eager,
+          'run_eagerly': flags_obj.run_eagerly,
          'tpu': flags_obj.tpu,
      },
      'train_dataset': {

@@ -154,8 +169,10 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
      'validation_dataset': {
          'data_dir': flags_obj.data_dir,
      },
-      'test_dataset': {
-          'data_dir': flags_obj.data_dir,
-      },
+      'train': {
+          'time_history': {
+              'log_steps': flags_obj.log_steps,
+          },
+      },
   }

@@ -169,8 +186,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
   for param in overriding_configs:
     logging.info('Overriding params: %s', param)
-    params = params_dict.override_params_dict(params, param, is_strict=True)
+    # Set is_strict to false because we can have dynamic dict parameters.
+    params = params_dict.override_params_dict(params, param, is_strict=False)

   params.validate()
   params.lock()

@@ -212,24 +228,21 @@ def resume_from_checkpoint(model: tf.keras.Model,
   return int(initial_epoch)


-def initialize(params: base_configs.ExperimentConfig):
+def initialize(params: base_configs.ExperimentConfig,
+               dataset_builder: dataset_factory.DatasetBuilder):
   """Initializes backend related initializations."""
   keras_utils.set_session_config(
-      enable_eager=params.runtime.enable_eager,
       enable_xla=params.runtime.enable_xla)
-  if params.runtime.gpu_threads_enabled:
+  if params.runtime.gpu_thread_mode:
     keras_utils.set_gpu_thread_mode_and_count(
         per_gpu_thread_count=params.runtime.per_gpu_thread_count,
         gpu_thread_mode=params.runtime.gpu_thread_mode,
         num_gpus=params.runtime.num_gpus,
        datasets_num_private_threads=params.runtime.dataset_num_private_threads)
-  dataset = params.train_dataset or params.validation_dataset
-  performance.set_mixed_precision_policy(dataset.dtype)
-  if dataset.data_format:
-    data_format = dataset.data_format
-  elif tf.config.list_physical_devices('GPU'):
+  performance.set_mixed_precision_policy(dataset_builder.dtype,
+                                         get_loss_scale(params))
+  if tf.config.list_physical_devices('GPU'):
     data_format = 'channels_first'
   else:
     data_format = 'channels_last'

@@ -237,7 +250,7 @@ def initialize(params: base_configs.ExperimentConfig):
   distribution_utils.configure_cluster(
       params.runtime.worker_hosts,
       params.runtime.task_index)
-  if params.runtime.enable_eager:
+  if params.runtime.run_eagerly:
     # Enable eager execution to allow step-by-step debugging
     tf.config.experimental_run_functions_eagerly(True)

@@ -254,7 +267,7 @@ def define_classifier_flags():
       default=None,
       help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
   flags.DEFINE_bool(
-      'enable_eager',
+      'run_eagerly',
       default=None,
       help='Use eager execution and disable autograph for debugging.')
   flags.DEFINE_string(

@@ -265,6 +278,10 @@ def define_classifier_flags():
       'dataset',
       default=None,
       help='The name of the dataset, e.g. ImageNet, etc.')
+  flags.DEFINE_integer(
+      'log_steps',
+      default=100,
+      help='The interval of steps between logging of batch level stats.')


 def serialize_config(params: base_configs.ExperimentConfig,

@@ -291,27 +308,31 @@ def train_and_eval(
   strategy_scope = distribution_utils.get_strategy_scope(strategy)

-  logging.info('Detected %d devices.', strategy.num_replicas_in_sync)
+  logging.info('Detected %d devices.',
+               strategy.num_replicas_in_sync if strategy else 1)

   label_smoothing = params.model.loss.label_smoothing
   one_hot = label_smoothing and label_smoothing > 0

   builders = _get_dataset_builders(params, strategy, one_hot)
-  datasets = [builder.build() if builder else None for builder in builders]
+  datasets = [builder.build(strategy) if builder else None for builder in builders]

   # Unpack datasets and builders based on train/val/test splits
-  train_builder, validation_builder, test_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
-  train_dataset, validation_dataset, test_dataset = datasets
+  train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
+  train_dataset, validation_dataset = datasets

   train_epochs = params.train.epochs
   train_steps = params.train.steps or train_builder.num_steps
   validation_steps = params.evaluation.steps or validation_builder.num_steps

+  initialize(params, train_builder)
+
   logging.info('Global batch size: %d', train_builder.global_batch_size)

   with strategy_scope:
     model_params = params.model.model_params.as_dict()
-    model = MODELS[params.model.name](**model_params)
+    model = get_models()[params.model.name](**model_params)
     learning_rate = optimizer_factory.build_learning_rate(
         params=params.model.learning_rate,
         batch_size=train_builder.global_batch_size,

@@ -332,7 +353,7 @@ def train_and_eval(
     model.compile(optimizer=optimizer,
                   loss=loss_obj,
                   metrics=metrics,
-                  run_eagerly=params.runtime.enable_eager)
+                  experimental_steps_per_execution=params.train.steps_per_loop)

     initial_epoch = 0
     if params.train.resume_checkpoint:

@@ -340,15 +361,27 @@ def train_and_eval(
           model_dir=params.model_dir,
           train_steps=train_steps)

-    serialize_config(params=params, model_dir=params.model_dir)
-    # TODO(dankondratyuk): callbacks significantly slow down training
-    callbacks = custom_callbacks.get_callbacks(
-        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
-        include_tensorboard=params.train.callbacks.enable_tensorboard,
-        track_lr=params.train.tensorboard.track_lr,
-        write_model_weights=params.train.tensorboard.write_model_weights,
-        initial_step=initial_epoch * train_steps,
-        model_dir=params.model_dir)
+    callbacks = custom_callbacks.get_callbacks(
+        model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
+        include_tensorboard=params.train.callbacks.enable_tensorboard,
+        time_history=params.train.callbacks.enable_time_history,
+        track_lr=params.train.tensorboard.track_lr,
+        write_model_weights=params.train.tensorboard.write_model_weights,
+        initial_step=initial_epoch * train_steps,
+        batch_size=train_builder.global_batch_size,
+        log_steps=params.train.time_history.log_steps,
+        model_dir=params.model_dir)
+    serialize_config(params=params, model_dir=params.model_dir)
+
+    if params.evaluation.skip_eval:
+      validation_kwargs = {}
+    else:
+      validation_kwargs = {
+          'validation_data': validation_dataset,
+          'validation_steps': validation_steps,
+          'validation_freq': params.evaluation.epochs_between_evals,
+      }

     history = model.fit(
         train_dataset,

@@ -356,15 +389,15 @@ def train_and_eval(
         steps_per_epoch=train_steps,
         initial_epoch=initial_epoch,
         callbacks=callbacks,
-        validation_data=validation_dataset,
-        validation_steps=validation_steps,
-        validation_freq=params.evaluation.epochs_between_evals)
+        verbose=2,
+        **validation_kwargs)

-    validation_output = model.evaluate(
-        validation_dataset, steps=validation_steps, verbose=2)
+    validation_output = None
+    if not params.evaluation.skip_eval:
+      validation_output = model.evaluate(
+          validation_dataset, steps=validation_steps, verbose=2)

     # TODO(dankondratyuk): eval and save final test accuracy
     stats = common.build_stats(history,
                                validation_output,
                                callbacks)

@@ -375,7 +408,7 @@ def export(params: base_configs.ExperimentConfig):
   """Runs the model export functionality."""
   logging.info('Exporting model.')
   model_params = params.model.model_params.as_dict()
-  model = MODELS[params.model.name](**model_params)
+  model = get_models()[params.model.name](**model_params)
   checkpoint = params.export.checkpoint
   if checkpoint is None:
     logging.info('No export checkpoint was provided. Using the latest '

@@ -398,8 +431,6 @@ def run(flags_obj: flags.FlagValues,
     Dictionary of training/eval stats
   """
   params = _get_params_from_flags(flags_obj)
-  initialize(params)
-
   if params.mode == 'train_and_eval':
     return train_and_eval(params, strategy_override)
   elif params.mode == 'export_only':

@@ -409,8 +440,7 @@ def run(flags_obj: flags.FlagValues,
 def main(_):
-  with logger.benchmark_context(flags.FLAGS):
-    stats = run(flags.FLAGS)
+  stats = run(flags.FLAGS)
   if stats:
     logging.info('Run stats:\n%s', stats)

@@ -423,5 +453,4 @@ if __name__ == '__main__':
   flags.mark_flag_as_required('model_type')
   flags.mark_flag_as_required('dataset')
-  assert tf.version.VERSION.startswith('2.')
   app.run(main)
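Two behavioral notes on the compile change above: eager debugging moved from a compile() kwarg to a global toggle (the run_eagerly flag drives tf.config.experimental_run_functions_eagerly), and experimental_steps_per_execution now batches train.steps_per_loop training steps into each tf.function call. A minimal sketch outside this codebase (the kwarg exists under this experimental name in the TF 2.2/2.3 era and was renamed steps_per_execution in TF 2.4):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')])
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    experimental_steps_per_execution=10)  # run 10 train steps per tf.function call
x = np.random.rand(320, 4).astype('float32')
y = np.random.randint(0, 2, size=(320,))
model.fit(x, y, batch_size=32, epochs=1, verbose=2)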
official/vision/image_classification/classifier_trainer_test.py  (+86 -17)

@@ -30,7 +30,7 @@ from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, T
 from absl import flags
 from absl.testing import parameterized
-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations

@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str:
   return '--params_override=' + json.dumps(params_override)


-def basic_params_override() -> MutableMapping[str, Any]:
+def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
   """Returns a basic parameter configuration for testing."""
   return {
       'train_dataset': {

@@ -75,18 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]:
           'use_per_replica_batch_size': True,
           'batch_size': 1,
           'image_size': 224,
+          'dtype': dtype,
       },
       'validation_dataset': {
           'builder': 'synthetic',
           'batch_size': 1,
           'use_per_replica_batch_size': True,
           'image_size': 224,
-      },
-      'test_dataset': {
-          'builder': 'synthetic',
-          'batch_size': 1,
-          'use_per_replica_batch_size': True,
-          'image_size': 224,
+          'dtype': dtype,
       },
       'train': {
           'steps': 1,

@@ -152,7 +148,7 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
     tf.io.gfile.rmtree(self.get_temp_dir())

   @combinations.generate(distribution_strategy_combinations())
-  def test_end_to_end_train_and_eval_export(self, distribution, model, dataset):
+  def test_end_to_end_train_and_eval(self, distribution, model, dataset):
     """Test train_and_eval and export for Keras classifier models."""
     # Some parameters are not defined as flags (e.g. cannot run
     # classifier_train.py --batch_size=...) by design, so use

@@ -168,6 +164,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
       '--mode=train_and_eval',
     ]
+    run = functools.partial(classifier_trainer.run,
+                            strategy_override=distribution)
+    run_end_to_end(main=run,
+                   extra_flags=train_and_eval_flags,
+                   model_dir=model_dir)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          model=[
+              'efficientnet',
+              'resnet',
+          ],
+          mode='eager',
+          dataset='imagenet',
+          dtype='float16',
+      ))
+  def test_gpu_train(self, distribution, model, dataset, dtype):
+    """Test train_and_eval and export for Keras classifier models."""
+    # Some parameters are not defined as flags (e.g. cannot run
+    # classifier_train.py --batch_size=...) by design, so use
+    # "--params_override=..." instead
+    model_dir = self.get_temp_dir()
+    base_flags = [
+        '--data_dir=not_used',
+        '--model_type=' + model,
+        '--dataset=' + dataset,
+    ]
+    train_and_eval_flags = base_flags + [
+        get_params_override(basic_params_override(dtype)),
+        '--mode=train_and_eval',
+    ]

     export_params = basic_params_override()
     export_path = os.path.join(model_dir, 'export')
     export_params['export'] = {}

@@ -187,6 +218,41 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
                    model_dir=model_dir)
     self.assertTrue(os.path.exists(export_path))

+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.tpu_strategy,
+          ],
+          model=[
+              'efficientnet',
+              'resnet',
+          ],
+          mode='eager',
+          dataset='imagenet',
+          dtype='bfloat16',
+      ))
+  def test_tpu_train(self, distribution, model, dataset, dtype):
+    """Test train_and_eval and export for Keras classifier models."""
+    # Some parameters are not defined as flags (e.g. cannot run
+    # classifier_train.py --batch_size=...) by design, so use
+    # "--params_override=..." instead
+    model_dir = self.get_temp_dir()
+    base_flags = [
+        '--data_dir=not_used',
+        '--model_type=' + model,
+        '--dataset=' + dataset,
+    ]
+    train_and_eval_flags = base_flags + [
+        get_params_override(basic_params_override(dtype)),
+        '--mode=train_and_eval',
+    ]
+    run = functools.partial(classifier_trainer.run,
+                            strategy_override=distribution)
+    run_end_to_end(main=run,
+                   extra_flags=train_and_eval_flags,
+                   model_dir=model_dir)
+
   @combinations.generate(distribution_strategy_combinations())
   def test_end_to_end_invalid_mode(self, distribution, model, dataset):
     """Test the Keras EfficientNet model with `strategy`."""

@@ -239,8 +305,8 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
   )
   def test_get_loss_scale(self, loss_scale, dtype, expected):
     config = base_configs.ExperimentConfig(
-        model=base_configs.ModelConfig(
-            loss=base_configs.LossConfig(loss_scale=loss_scale)),
+        runtime=base_configs.RuntimeConfig(
+            loss_scale=loss_scale),
         train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
     ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
     self.assertEqual(ls, expected)

@@ -252,19 +318,23 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
   def test_initialize(self, dtype):
     config = base_configs.ExperimentConfig(
         runtime=base_configs.RuntimeConfig(
-            enable_eager=False,
+            run_eagerly=False,
             enable_xla=False,
-            gpu_threads_enabled=True,
             per_gpu_thread_count=1,
             gpu_thread_mode='gpu_private',
             num_gpus=1,
             dataset_num_private_threads=1,
         ),
         train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
-        model=base_configs.ModelConfig(
-            loss=base_configs.LossConfig(loss_scale='dynamic')),
+        model=base_configs.ModelConfig(),
     )
-    classifier_trainer.initialize(config)
+
+    class EmptyClass:
+      pass
+    fake_ds_builder = EmptyClass()
+    fake_ds_builder.dtype = dtype
+    fake_ds_builder.config = EmptyClass()
+    classifier_trainer.initialize(config, fake_ds_builder)

   def test_resume_from_checkpoint(self):
     """Tests functionality for resuming from checkpoint."""

@@ -313,5 +383,4 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
       tf.io.gfile.rmtree(model_dir)

 if __name__ == '__main__':
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()
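The dtype parameterization above reaches the trainer entirely through the --params_override flag built by get_params_override. A standard-library-only sketch of the flag the helper produces:

import json

override = {'train_dataset': {'dtype': 'bfloat16'},
            'validation_dataset': {'dtype': 'bfloat16'}}
flag = '--params_override=' + json.dumps(override)
print(flag)
# --params_override={"train_dataset": {"dtype": "bfloat16"}, "validation_dataset": {"dtype": "bfloat16"}}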
official/vision/image_classification/configs/base_configs.py  (+21 -5)

@@ -58,6 +58,17 @@ class MetricsConfig(base_config.Config):
   top_5: bool = None


+@dataclasses.dataclass
+class TimeHistoryConfig(base_config.Config):
+  """Configuration for the TimeHistory callback.
+
+  Attributes:
+    log_steps: Interval of steps between logging of batch level stats.
+  """
+  log_steps: int = None
+
+
 @dataclasses.dataclass
 class TrainConfig(base_config.Config):
   """Configuration for training.

@@ -71,14 +82,18 @@ class TrainConfig(base_config.Config):
     callbacks: An instance of CallbacksConfig.
     metrics: An instance of MetricsConfig.
     tensorboard: An instance of TensorboardConfig.
+    steps_per_loop: The number of batches to run during each `tf.function`
+      call during training, which can increase training speed.
   """
   resume_checkpoint: bool = None
   epochs: int = None
   steps: int = None
   callbacks: CallbacksConfig = CallbacksConfig()
-  metrics: List[str] = None
+  metrics: MetricsConfig = None
   tensorboard: TensorboardConfig = TensorboardConfig()
+  time_history: TimeHistoryConfig = TimeHistoryConfig()
+  steps_per_loop: int = None


 @dataclasses.dataclass

@@ -91,10 +106,12 @@ class EvalConfig(base_config.Config):
     steps: The number of eval steps to run during evaluation. If None, this will
       be inferred based on the number of images and batch size. Defaults to
       None.
+    skip_eval: Whether or not to skip evaluation.
   """
   epochs_between_evals: int = None
   steps: int = None
+  skip_eval: bool = False


 @dataclasses.dataclass

@@ -103,13 +120,11 @@ class LossConfig(base_config.Config):
   Attributes:
     name: The name of the loss. Defaults to None.
-    loss_scale: The type of loss scale
     label_smoothing: Whether or not to apply label smoothing to the loss. This
       only applies to 'categorical_cross_entropy'.
   """
   name: str = None
-  loss_scale: str = None
   label_smoothing: float = None

@@ -164,6 +179,7 @@ class LearningRateConfig(base_config.Config):
     multipliers: multipliers used in piecewise constant decay with warmup.
     scale_by_batch_size: Scale the learning rate by a fraction of the batch
       size. Set to 0 for no scaling (default).
+    staircase: Apply exponential decay at discrete values instead of continuous.
   """
   name: str = None

@@ -175,6 +191,7 @@ class LearningRateConfig(base_config.Config):
   boundaries: List[int] = None
   multipliers: List[float] = None
   scale_by_batch_size: float = 0.
+  staircase: bool = None


 @dataclasses.dataclass

@@ -190,7 +207,7 @@ class ModelConfig(base_config.Config):
   """
   name: str = None
-  model_params: Mapping[str, Any] = None
+  model_params: base_config.Config = None
   num_classes: int = None
   loss: LossConfig = None
   optimizer: OptimizerConfig = None

@@ -216,7 +233,6 @@ class ExperimentConfig(base_config.Config):
   runtime: RuntimeConfig = None
   train_dataset: Any = None
   validation_dataset: Any = None
-  test_dataset: Any = None
   train: TrainConfig = None
   evaluation: EvalConfig = None
   model: ModelConfig = None
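The new TimeHistoryConfig nests inside TrainConfig the same way the other sub-configs do. A self-contained sketch of that pattern with plain dataclasses (the real classes inherit from the repo's base_config.Config, which is what lets TimeHistoryConfig() be used directly as a field default; plain dataclasses need default_factory because dataclass instances with eq=True are unhashable):

import dataclasses

@dataclasses.dataclass
class TimeHistoryConfig:
  log_steps: int = None

@dataclasses.dataclass
class TrainConfig:
  steps_per_loop: int = None
  time_history: TimeHistoryConfig = dataclasses.field(
      default_factory=TimeHistoryConfig)

cfg = TrainConfig(steps_per_loop=1,
                  time_history=TimeHistoryConfig(log_steps=100))
print(cfg.time_history.log_steps)  # 100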
official/vision/image_classification/configs/configs.py  (+6 -9)

@@ -45,8 +45,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
       dataset_factory.ImageNetConfig(split='train')
   validation_dataset: dataset_factory.DatasetConfig = \
       dataset_factory.ImageNetConfig(split='validation')
-  test_dataset: dataset_factory.DatasetConfig = \
-      dataset_factory.ImageNetConfig(split='validation')
   train: base_configs.TrainConfig = base_configs.TrainConfig(
       resume_checkpoint=True,
       epochs=500,

@@ -54,8 +52,10 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
       callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
                                              enable_tensorboard=True),
       metrics=['accuracy', 'top_5'],
+      time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
-                                                 write_model_weights=False))
+                                                 write_model_weights=False),
+      steps_per_loop=1)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)

@@ -78,11 +78,6 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       one_hot=False,
       mean_subtract=True,
       standardize=True)
-  test_dataset: dataset_factory.DatasetConfig = \
-      dataset_factory.ImageNetConfig(split='validation',
-                                     one_hot=False,
-                                     mean_subtract=True,
-                                     standardize=True)
   train: base_configs.TrainConfig = base_configs.TrainConfig(
       resume_checkpoint=True,
       epochs=90,

@@ -90,8 +85,10 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
                                              enable_tensorboard=True),
       metrics=['accuracy', 'top_5'],
+      time_history=base_configs.TimeHistoryConfig(log_steps=100),
       tensorboard=base_configs.TensorboardConfig(track_lr=True,
-                                                 write_model_weights=False))
+                                                 write_model_weights=False),
+      steps_per_loop=1)
   evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
       epochs_between_evals=1,
       steps=None)
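These defaults are starting points: the trainer now applies overrides with is_strict=False (see classifier_trainer.py above), so nested keys added in this commit, such as train.time_history, can also be injected at runtime. A hedged sketch, assuming the params_dict helper accepts a plain dict as the diff suggests:

from official.modeling.hyperparams import params_dict
from official.vision.image_classification.configs import configs

params = configs.ResNetImagenetConfig()
params = params_dict.override_params_dict(
    params,
    {'train': {'time_history': {'log_steps': 50}}},
    is_strict=False)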
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml  (+3 -2)

@@ -3,8 +3,6 @@
 # Reaches ~76.1% within 350 epochs.
 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'mirrored'
   num_gpus: 1
 train_dataset:

@@ -36,10 +34,13 @@ model:
     num_classes: 1000
     batch_norm: 'default'
     dtype: 'float32'
+    activation: 'swish'
   optimizer:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml

@@ -3,8 +3,6 @@
 # Reaches ~76.1% within 350 epochs.
 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'tpu'
 train_dataset:
   name: 'imagenet2012'
...
@@ -35,11 +33,12 @@ model:
     num_classes: 1000
     batch_norm: 'tpu'
   dtype: 'bfloat16'
+  activation: 'swish'
   optimizer:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
     moving_average_decay: 0.0
     lookahead: false
   learning_rate:
     name: 'exponential'
...
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml

 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'mirrored'
   num_gpus: 1
 train_dataset:
...
@@ -12,6 +10,7 @@ train_dataset:
   num_classes: 1000
   num_examples: 1281167
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 validation_dataset:
   name: 'imagenet2012'
...
@@ -21,6 +20,7 @@ validation_dataset:
   num_classes: 1000
   num_examples: 50000
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 model:
   model_params:
...
@@ -29,10 +29,13 @@ model:
     num_classes: 1000
     batch_norm: 'default'
   dtype: 'float32'
+  activation: 'swish'
   optimizer:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
...
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml

@@ -2,8 +2,6 @@
 # Takes ~3 minutes, 15 seconds per epoch for v3-32.
 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'tpu'
 train_dataset:
   name: 'imagenet2012'
...
@@ -34,10 +32,13 @@ model:
     num_classes: 1000
     batch_norm: 'tpu'
   dtype: 'bfloat16'
+  activation: 'swish'
   optimizer:
     name: 'rmsprop'
     momentum: 0.9
     decay: 0.9
+    moving_average_decay: 0.0
+    lookahead: false
   learning_rate:
     name: 'exponential'
   loss:
...
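All four EfficientNet examples now pin activation: 'swish' and spell out the new optimizer fields. A minimal sketch of reading such a config with PyYAML (the embedded YAML is abridged from the files above; yaml is the third-party PyYAML package):

import yaml  # PyYAML, assumed installed

CONFIG = """
model:
  dtype: 'float32'
  activation: 'swish'
  optimizer:
    name: 'rmsprop'
    momentum: 0.9
    decay: 0.9
    moving_average_decay: 0.0
    lookahead: false
"""

cfg = yaml.safe_load(CONFIG)
opt = cfg['model']['optimizer']
# lookahead parses to a Python bool, moving_average_decay to a float.
print(opt['name'], opt['moving_average_decay'], opt['lookahead'])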
official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml

 # Training configuration for ResNet trained on ImageNet on GPUs.
-# Takes ~3 minutes, 15 seconds per epoch for 8 V100s.
-# Reaches > 76.1% within 90 epochs.
+# Reaches ~76.1% within 90 epochs.
 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'mirrored'
   num_gpus: 1
 train_dataset:
   name: 'imagenet2012'
   data_dir: null
-  builder: 'records'
+  builder: 'tfds'
   split: 'train'
   image_size: 224
   num_classes: 1000
...
@@ -23,7 +20,7 @@ train_dataset:
 validation_dataset:
   name: 'imagenet2012'
   data_dir: null
-  builder: 'records'
+  builder: 'tfds'
   split: 'validation'
   image_size: 224
   num_classes: 1000
...
@@ -34,7 +31,7 @@ validation_dataset:
   mean_subtract: True
   standardize: True
 model:
-  model_name: 'resnet'
+  name: 'resnet'
   model_params:
     rescale_inputs: False
   optimizer:
...
official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml

 # Training configuration for ResNet trained on ImageNet on TPUs.
-# Takes ~2 minutes, 43 seconds per epoch for a v3-32.
-# Reaches ~76.1% within 90 epochs.
+# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
 # Note: This configuration uses a scaled per-replica batch size based on the number of devices.
 runtime:
-  model_dir: null
-  mode: 'train_and_eval'
   distribution_strategy: 'tpu'
 train_dataset:
   name: 'imagenet2012'
   data_dir: null
-  builder: 'records'
+  builder: 'tfds'
   split: 'train'
   one_hot: False
   image_size: 224
...
@@ -23,7 +21,7 @@ train_dataset:
 validation_dataset:
   name: 'imagenet2012'
   data_dir: null
-  builder: 'records'
+  builder: 'tfds'
   split: 'validation'
   one_hot: False
   image_size: 224
...
@@ -35,7 +33,7 @@ validation_dataset:
   standardize: True
   dtype: 'bfloat16'
 model:
-  model_name: 'resnet'
+  name: 'resnet'
   model_params:
     rescale_inputs: False
   optimizer:
...
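Both ResNet configs switch builder from 'records' to 'tfds'. The builder string is dispatched through a dict of loader methods in dataset_factory.py (next file); a minimal standalone sketch of that pattern, with stub loaders standing in for the real ones:

from typing import Callable, Dict

def load_tfds() -> str:
  return 'loaded via tensorflow_datasets'  # stub for DatasetBuilder.load_tfds

def load_records() -> str:
  return 'loaded from TFRecord files'      # stub for DatasetBuilder.load_records

BUILDERS: Dict[str, Callable[[], str]] = {
    'tfds': load_tfds,
    'records': load_records,
}

builder = 'tfds'  # the value these configs now set
print(BUILDERS[builder]())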
official/vision/image_classification/dataset_factory.py

@@ -23,7 +23,7 @@ import os
 from typing import Any, List, Optional, Tuple, Mapping, Union

 from absl import logging
 from dataclasses import dataclass
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
 import tensorflow_datasets as tfds

 from official.modeling.hyperparams import base_config
...
@@ -84,11 +84,10 @@ class DatasetConfig(base_config.Config):
     use_per_replica_batch_size: Whether to scale the batch size based on
       available resources. If set to `True`, the dataset builder will return
       batch_size multiplied by `num_devices`, the number of device replicas
-      (e.g., the number of GPUs or TPU cores).
+      (e.g., the number of GPUs or TPU cores). This setting should be `True` if
+      the strategy argument is passed to `build()` and `num_devices > 1`.
     num_devices: The number of replica devices to use. This should be set by
       `strategy.num_replicas_in_sync` when using a distribution strategy.
-    data_format: The data format of the images. Should be 'channels_last' or
-      'channels_first'.
     dtype: The desired dtype of the dataset. This will be set during
       preprocessing.
     one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
...
@@ -118,9 +117,8 @@ class DatasetConfig(base_config.Config):
   num_channels: Union[int, str] = 'infer'
   num_examples: Union[int, str] = 'infer'
   batch_size: int = 128
-  use_per_replica_batch_size: bool = False
+  use_per_replica_batch_size: bool = True
   num_devices: int = 1
-  data_format: str = 'channels_last'
   dtype: str = 'float32'
   one_hot: bool = True
   augmenter: AugmentConfig = AugmentConfig()
...
@@ -188,14 +186,22 @@ class DatasetBuilder:
   def batch_size(self) -> int:
     """The batch size, multiplied by the number of replicas (if configured)."""
     if self.config.use_per_replica_batch_size:
-      return self.global_batch_size
+      return self.config.batch_size * self.config.num_devices
     else:
       return self.config.batch_size

   @property
   def global_batch_size(self):
     """The global batch size across all replicas."""
-    return self.config.batch_size * self.config.num_devices
+    return self.batch_size
+
+  @property
+  def local_batch_size(self):
+    """The base unscaled batch size."""
+    if self.config.use_per_replica_batch_size:
+      return self.config.batch_size
+    else:
+      return self.config.batch_size // self.config.num_devices

   @property
   def num_steps(self) -> int:
...
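Restating the three batch-size properties outside the class (illustrative values, not code from the commit):

# With use_per_replica_batch_size=True the config stores the per-replica size:
batch_size, num_devices = 128, 4
global_batch_size = batch_size * num_devices   # 512, used for batching/steps
local_batch_size = batch_size                  # 128, per replica

# With use_per_replica_batch_size=False the config stores the global size:
global_batch_size = batch_size                 # 128
local_batch_size = batch_size // num_devices   # 32 per replica

num_steps then divides the example count by the global batch size, e.g. 1281167 // 512 = 2502 steps per ImageNet epoch.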
@@ -203,6 +209,30 @@ class DatasetBuilder:
     # Always divide by the global batch size to get the correct # of steps
     return self.num_examples // self.global_batch_size

+  @property
+  def dtype(self) -> tf.dtypes.DType:
+    """Converts the config's dtype string to a tf dtype.
+
+    Returns:
+      A mapping from string representation of a dtype to the `tf.dtypes.DType`.
+
+    Raises:
+      ValueError if the config's dtype is not supported.
+    """
+    dtype_map = {
+        'float32': tf.float32,
+        'bfloat16': tf.bfloat16,
+        'float16': tf.float16,
+        'fp32': tf.float32,
+        'bf16': tf.bfloat16,
+    }
+    try:
+      return dtype_map[self.config.dtype]
+    except:
+      raise ValueError('Invalid DType provided. Supported types: {}'.format(
+          dtype_map.keys()))
+
   @property
   def image_size(self) -> int:
     """The size of each image (can be inferred from the dataset)."""
...
@@ -243,19 +273,42 @@ class DatasetBuilder:
       self.builder_info = tfds.builder(self.config.name).info
     return self.builder_info

-  def build(self, input_context: tf.distribute.InputContext = None
-           ) -> tf.data.Dataset:
+  def build(self, strategy: tf.distribute.Strategy = None
+           ) -> tf.data.Dataset:
+    """Construct a dataset end-to-end and return it using an optional strategy.
+
+    Args:
+      strategy: a strategy that, if passed, will distribute the dataset
+        according to that strategy. If passed and `num_devices > 1`,
+        `use_per_replica_batch_size` must be set to `True`.
+
+    Returns:
+      A TensorFlow dataset outputting batched images and labels.
+    """
+    if strategy:
+      if strategy.num_replicas_in_sync != self.config.num_devices:
+        logging.warn('Passed a strategy with %d devices, but expected'
+                     '%d devices.',
+                     strategy.num_replicas_in_sync,
+                     self.config.num_devices)
+      dataset = strategy.experimental_distribute_datasets_from_function(
+          self._build)
+    else:
+      dataset = self._build()
+
+    return dataset
+
+  def _build(self, input_context: tf.distribute.InputContext = None
+            ) -> tf.data.Dataset:
     """Construct a dataset end-to-end and return it.

     Args:
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
     """
     builders = {
         'tfds': self.load_tfds,
         'records': self.load_records,
...
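A hedged usage sketch of the new build(strategy) entry point; the DatasetBuilder constructor arguments and ImageNetConfig fields here are assumptions based on this diff, not verified against the full file:

import tensorflow as tf
from official.vision.image_classification import dataset_factory

strategy = tf.distribute.MirroredStrategy()

# num_devices must match the strategy, and use_per_replica_batch_size must be
# True for multi-replica use, per the ValueError added further down.
config = dataset_factory.ImageNetConfig(
    split='train',
    batch_size=128,
    use_per_replica_batch_size=True,
    num_devices=strategy.num_replicas_in_sync)

builder = dataset_factory.DatasetBuilder(config)
dataset = builder.build(strategy)  # distributes _build across replicas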
@@ -326,7 +379,7 @@ class DatasetBuilder:
     def generate_data(_):
       image = tf.zeros([self.image_size, self.image_size, self.num_channels],
-                       dtype=self.config.dtype)
+                       dtype=self.dtype)
       label = tf.zeros([1], dtype=tf.int32)
       return image, label
...
@@ -345,8 +398,8 @@ class DatasetBuilder:
     Args:
       dataset: A `tf.data.Dataset` that loads raw files.
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training. If set with more than one replica, this
+        function assumes `use_per_replica_batch_size=True`.

     Returns:
       A TensorFlow dataset outputting batched images and labels.
...
@@ -366,8 +419,6 @@ class DatasetBuilder:
           cycle_length=16,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
-      dataset = dataset.prefetch(self.global_batch_size)
-
     if self.config.cache:
       dataset = dataset.cache()
...
@@ -383,13 +434,25 @@ class DatasetBuilder:
       dataset = dataset.map(preprocess,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)

-    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)
-
-    # Note: we could do image normalization here, but we defer it to the model
-    # which can perform it much faster on a GPU/TPU
-    # TODO(dankondratyuk): if we fix prefetching, we can do it here
-
-    if self.is_training and self.config.deterministic_train is not None:
+    if input_context and self.config.num_devices > 1:
+      if not self.config.use_per_replica_batch_size:
+        raise ValueError(
+            'The builder does not support a global batch size with more than '
+            'one replica. Got {} replicas. Please set a '
+            '`per_replica_batch_size` and enable '
+            '`use_per_replica_batch_size=True`.'.format(
+                self.config.num_devices))
+
+      # The batch size of the dataset will be multiplied by the number of
+      # replicas automatically when strategy.distribute_datasets_from_function
+      # is called, so we use local batch size here.
+      dataset = dataset.batch(self.local_batch_size,
+                              drop_remainder=self.is_training)
+    else:
+      dataset = dataset.batch(self.global_batch_size,
+                              drop_remainder=self.is_training)
+
+    if self.is_training:
       options = tf.data.Options()
       options.experimental_deterministic = self.config.deterministic_train
       options.experimental_slack = self.config.use_slack
...
@@ -400,9 +463,7 @@ class DatasetBuilder:
     dataset = dataset.with_options(options)

     # Prefetch overlaps in-feed with training
-    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    # Note: autotune here is not recommended, as this can lead to memory leaks.
+    # Instead, use a constant prefetch size like the number of devices.
+    dataset = dataset.prefetch(self.config.num_devices)

     return dataset
...
@@ -451,7 +512,7 @@ class DatasetBuilder:
           image_size=self.image_size,
           mean_subtract=self.config.mean_subtract,
           standardize=self.config.standardize,
-          dtype=self.config.dtype,
+          dtype=self.dtype,
           augmenter=self.augmenter)
     else:
       image = preprocessing.preprocess_for_eval(
...
@@ -460,7 +521,7 @@ class DatasetBuilder:
           num_channels=self.num_channels,
           mean_subtract=self.config.mean_subtract,
           standardize=self.config.standardize,
-          dtype=self.config.dtype)
+          dtype=self.dtype)

     label = tf.cast(label, tf.int32)
     if self.config.one_hot:
...
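Preprocessing and the zeros-dataset now draw their dtype from the builder's new dtype property rather than the raw config string. A standalone check of the same mapping (assuming TensorFlow 2.x; the bare except: from the diff is narrowed to KeyError here):

import tensorflow as tf

DTYPE_MAP = {
    'float32': tf.float32,
    'bfloat16': tf.bfloat16,
    'float16': tf.float16,
    'fp32': tf.float32,
    'bf16': tf.bfloat16,
}

def to_tf_dtype(name: str) -> tf.dtypes.DType:
  # Raise a helpful error for unsupported strings, like DatasetBuilder.dtype.
  try:
    return DTYPE_MAP[name]
  except KeyError:
    raise ValueError('Invalid DType provided. Supported types: {}'.format(
        list(DTYPE_MAP.keys())))

assert to_tf_dtype('bf16') == tf.bfloat16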
official/vision/image_classification/efficientnet/common_modules.py

@@ -19,15 +19,14 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np
+import tensorflow as tf
 import tensorflow.compat.v1 as tf1
-import tensorflow.compat.v2 as tf
 from typing import Text, Optional

 from tensorflow.python.tpu import tpu_function


-@tf.keras.utils.register_keras_serializable(package='Text')
+@tf.keras.utils.register_keras_serializable(package='Vision')
 class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
   """Cross replica batch normalization."""
...
@@ -98,3 +97,21 @@ def count_params(model, trainable_only=True):
   else:
     return int(np.sum([tf.keras.backend.count_params(p)
                        for p in model.trainable_weights]))
+
+
+def load_weights(model: tf.keras.Model,
+                 model_weights_path: Text,
+                 weights_format: Text = 'saved_model'):
+  """Load model weights from the given file path.
+
+  Args:
+    model: the model to load weights into
+    model_weights_path: the path of the model weights
+    weights_format: the model weights format. One of 'saved_model', 'h5',
+      or 'checkpoint'.
+  """
+  if weights_format == 'saved_model':
+    loaded_model = tf.keras.models.load_model(model_weights_path)
+    model.set_weights(loaded_model.get_weights())
+  else:
+    model.load_weights(model_weights_path)
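A hedged usage sketch of the new helper; the weights path is a placeholder and the 'h5' branch is chosen only for illustration:

# Illustrative only: restore EfficientNet weights from an HDF5 file via the
# new common_modules.load_weights helper.
from official.vision.image_classification.efficientnet import common_modules
from official.vision.image_classification.efficientnet import efficientnet_model

model = efficientnet_model.EfficientNet.from_name('efficientnet-b0')
common_modules.load_weights(model, '/path/to/weights.h5', weights_format='h5')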
official/vision/image_classification/efficientnet/efficientnet_config.py

@@ -22,6 +22,7 @@ from typing import Any, Mapping
 import dataclasses

+from official.modeling.hyperparams import base_config
 from official.vision.image_classification.configs import base_configs
...
@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
       configuration.
     learning_rate: The configuration for learning rate. Defaults to an
       exponential configuration.
   """
   name: str = 'EfficientNet'
   num_classes: int = 1000
-  model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
-      'model_name': 'efficientnet-b0',
-      'model_weights_path': '',
-      'copy_to_local': False,
-      'overrides': {
-          'batch_norm': 'default',
-          'rescale_input': True,
-          'num_classes': 1000,
-      }
-  })
+  model_params: base_config.Config = dataclasses.field(
+      default_factory=lambda: {
+          'model_name': 'efficientnet-b0',
+          'model_weights_path': '',
+          'weights_format': 'saved_model',
+          'overrides': {
+              'batch_norm': 'default',
+              'rescale_input': True,
+              'num_classes': 1000,
+              'activation': 'swish',
+              'dtype': 'float32',
+          }
+      })
   loss: base_configs.LossConfig = base_configs.LossConfig(
       name='categorical_crossentropy',
       label_smoothing=0.1)
   optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
       name='rmsprop',
       decay=0.9,
...
@@ -72,4 +74,5 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
       decay_epochs=2.4,
       decay_rate=0.97,
       warmup_epochs=5,
-      scale_by_batch_size=1. / 128.)
+      scale_by_batch_size=1. / 128.,
+      staircase=True)
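The learning-rate schedule gains staircase=True, so the exponential decay is applied in discrete jumps instead of continuously. A standalone sketch with Keras' built-in schedule (decay_epochs and decay_rate come from the config above; the base LR and steps-per-epoch are illustrative assumptions):

import tensorflow as tf

steps_per_epoch = 1251          # assumption; depends on the global batch size
decay_epochs, decay_rate = 2.4, 0.97

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.008,  # illustrative base LR
    decay_steps=int(decay_epochs * steps_per_epoch),  # 3002 steps
    decay_rate=decay_rate,
    staircase=True)  # LR drops by 3% every 2.4 epochs, in one discrete step

print(schedule(3001).numpy(), schedule(3002).numpy())  # 0.008 then 0.00776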
official/vision/image_classification/efficientnet/efficientnet_model.py

@@ -30,7 +30,7 @@ from typing import Any, Dict, Optional, Text, Tuple
 from absl import logging
 from dataclasses import dataclass

-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
...
@@ -104,6 +104,8 @@ MODEL_CONFIGS = {
     'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
     'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
     'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
+    'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
+    'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
 }

 CONV_KERNEL_INITIALIZER = {
...
@@ -166,7 +168,7 @@ def conv2d_block(inputs: tf.Tensor,
   batch_norm = common_modules.get_batch_norm(config.batch_norm)
   bn_momentum = config.bn_momentum
   bn_epsilon = config.bn_epsilon
-  data_format = config.data_format
+  data_format = tf.keras.backend.image_data_format()
   weight_decay = config.weight_decay

   name = name or ''
...
@@ -223,7 +225,7 @@ def mb_conv_block(inputs: tf.Tensor,
   use_se = config.use_se
   activation = tf_utils.get_activation(config.activation)
   drop_connect_rate = config.drop_connect_rate
-  data_format = config.data_format
+  data_format = tf.keras.backend.image_data_format()
   use_depthwise = block.conv_type != 'no_depthwise'
   prefix = prefix or ''
...
@@ -346,12 +348,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
   num_classes = config.num_classes
   input_channels = config.input_channels
   rescale_input = config.rescale_input
-  data_format = config.data_format
+  data_format = tf.keras.backend.image_data_format()
   dtype = config.dtype
   weight_decay = config.weight_decay

   x = image_input
+  if data_format == 'channels_first':
+    # Happens on GPU/TPU if available.
+    x = tf.keras.layers.Permute((3, 1, 2))(x)
   if rescale_input:
     x = preprocessing.normalize_images(x,
                                        num_channels=input_channels,
...
@@ -463,7 +467,7 @@ class EfficientNet(tf.keras.Model):
   def from_name(cls,
                 model_name: Text,
                 model_weights_path: Text = None,
-                copy_to_local: bool = False,
+                weights_format: Text = 'saved_model',
                 overrides: Dict[Text, Any] = None):
     """Construct an EfficientNet model from a predefined model name.
...
@@ -472,7 +476,8 @@ class EfficientNet(tf.keras.Model):
     Args:
       model_name: the predefined model name
       model_weights_path: the path to the weights (h5 file or saved model dir)
-      copy_to_local: copy the weights to a local tmp dir
+      weights_format: the model weights format. One of 'saved_model', 'h5',
+        or 'checkpoint'.
       overrides: (optional) a dict containing keys that can override config

     Returns:
...
@@ -492,12 +497,8 @@ class EfficientNet(tf.keras.Model):
     model = cls(config=config, overrides=overrides)

     if model_weights_path:
-      if copy_to_local:
-        tmp_file = os.path.join('/tmp', model_name + '.h5')
-        model_weights_file = os.path.join(model_weights_path, 'model.h5')
-        tf.io.gfile.copy(model_weights_file, tmp_file, overwrite=True)
-        model_weights_path = tmp_file
-      model.load_weights(model_weights_path)
+      common_modules.load_weights(model,
+                                  model_weights_path,
+                                  weights_format=weights_format)

     return model
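With MODEL_CONFIGS extended, the two new variants can be instantiated like any other. A hedged sketch (efficientnet-l2 at 800x800 input is very memory-hungry, so b8 is shown):

from official.vision.image_classification.efficientnet import efficientnet_model

# The new entries scale width/depth/resolution/dropout per MODEL_CONFIGS:
# b8 -> (2.2, 3.6, 672, 0.5), l2 -> (4.3, 5.3, 800, 0.5).
model = efficientnet_model.EfficientNet.from_name('efficientnet-b8')
model.summary()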
official/vision/image_classification/efficientnet/tfhub_export.py (new file, mode 100644)

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export TF-Hub SavedModel."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os

from absl import app
from absl import flags

import tensorflow as tf

from official.vision.image_classification.efficientnet import efficientnet_model

FLAGS = flags.FLAGS

flags.DEFINE_string("model_name", None, "EfficientNet model name.")
flags.DEFINE_string("model_path", None, "File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
                    "TF-Hub SavedModel destination path to export.")


def export_tfhub(model_path, hub_destination, model_name):
  """Restores a tf.keras.Model and saves for TF-Hub."""
  model = efficientnet_model.EfficientNet.from_name(model_name)
  ckpt = tf.train.Checkpoint(model=model)
  ckpt.restore(model_path).assert_existing_objects_matched()

  image_input = tf.keras.layers.Input(
      shape=(None, None, 3), name="image_input", dtype=tf.float32)
  x = image_input * 255.0
  outputs = model(x)
  hub_model = tf.keras.Model(image_input, outputs)

  # Exports a SavedModel.
  hub_model.save(
      os.path.join(hub_destination, "classification"), include_optimizer=False)

  feature_vector_output = hub_model.get_layer(name="efficientnet").get_layer(
      name="top_pool").get_output_at(0)
  hub_model2 = tf.keras.Model(model.inputs, feature_vector_output)

  # Exports a SavedModel.
  hub_model2.save(
      os.path.join(hub_destination, "feature-vector"), include_optimizer=False)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)


if __name__ == "__main__":
  app.run(main)
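A hedged programmatic invocation of the exporter; all paths are placeholders:

# Illustrative only: restore a checkpoint for b0 and write the two TF-Hub
# SavedModels described above.
from official.vision.image_classification.efficientnet import tfhub_export

tfhub_export.export_tfhub(
    model_path='/tmp/efficientnet-b0/ckpt',      # TF checkpoint prefix
    hub_destination='/tmp/efficientnet-b0/hub',  # output SavedModel root
    model_name='efficientnet-b0')
# Produces /tmp/efficientnet-b0/hub/classification and .../feature-vector.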
official/vision/image_classification/learning_rate.py

@@ -20,7 +20,7 @@ from __future__ import print_function

 from typing import Any, List, Mapping

-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 BASE_LEARNING_RATE = 0.1
...
official/vision/image_classification/learning_rate_test.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow.compat.v2 as tf
+import tensorflow as tf

 from official.vision.image_classification import learning_rate
...
@@ -86,5 +86,4 @@ class LearningRateTests(tf.test.TestCase):
 if __name__ == '__main__':
-  assert tf.version.VERSION.startswith('2.')
   tf.test.main()