ModelZoo / ResNet50_tensorflow

Commit 61a61902 authored Apr 20, 2020 by A. Unique TensorFlower

Use `strategy.distribute_datasets_from_function` in the classifier trainer.

PiperOrigin-RevId: 307483983

parent ba772461
Showing 3 changed files with 63 additions and 21 deletions.
official/vision/image_classification/classifier_trainer.py (+2, -2)
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml (+2, -0)
official/vision/image_classification/dataset_factory.py (+59, -19)
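For context, the tf.distribute pattern this commit adopts looks roughly like the sketch below. It only illustrates the TensorFlow API used in the diff (`experimental_distribute_datasets_from_function`, later shortened to `distribute_datasets_from_function`); the `make_dataset` function and batch size are hypothetical and not taken from this repository.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def make_dataset(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
  # Hypothetical dataset_fn: it is invoked once per input pipeline and must
  # return a dataset batched with the per-replica batch size, not the global one.
  del input_context  # a fuller dataset_fn would use it for sharding/batch sizing
  return tf.data.Dataset.range(1024).batch(32)

# The strategy calls make_dataset for each input pipeline and stitches the
# results into a single distributed dataset.
dist_dataset = strategy.experimental_distribute_datasets_from_function(make_dataset)
for per_replica_batch in dist_dataset:
  pass  # each element carries one batch per replica
```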
official/vision/image_classification/classifier_trainer.py
@@ -32,7 +32,6 @@ import tensorflow as tf
 from official.modeling import performance
 from official.modeling.hyperparams import params_dict
 from official.utils import hyperparams_flags
-from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.vision.image_classification import callbacks as custom_callbacks
@@ -316,7 +315,8 @@ def train_and_eval(
   one_hot = label_smoothing and label_smoothing > 0
   builders = _get_dataset_builders(params, strategy, one_hot)
-  datasets = [builder.build() if builder else None for builder in builders]
+  datasets = [
+      builder.build(strategy) if builder else None for builder in builders]
   # Unpack datasets and builders based on train/val/test splits
   train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
official/vision/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml
@@ -10,6 +10,7 @@ train_dataset:
   num_classes: 1000
   num_examples: 1281167
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 validation_dataset:
   name: 'imagenet2012'
@@ -19,6 +20,7 @@ validation_dataset:
   num_classes: 1000
   num_examples: 50000
   batch_size: 32
+  use_per_replica_batch_size: True
   dtype: 'float32'
 model:
   model_params:
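With `use_per_replica_batch_size: True`, the `batch_size: 32` above is interpreted per replica rather than globally: on a single GPU the global batch stays 32, while on, say, 8 GPUs (a hypothetical count, not part of this config) the effective global batch size becomes 32 × 8 = 256.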
official/vision/image_classification/dataset_factory.py
@@ -84,7 +84,8 @@ class DatasetConfig(base_config.Config):
     use_per_replica_batch_size: Whether to scale the batch size based on
       available resources. If set to `True`, the dataset builder will return
       batch_size multiplied by `num_devices`, the number of device replicas
-      (e.g., the number of GPUs or TPU cores).
+      (e.g., the number of GPUs or TPU cores). This setting should be `True` if
+      the strategy argument is passed to `build()` and `num_devices > 1`.
     num_devices: The number of replica devices to use. This should be set by
       `strategy.num_replicas_in_sync` when using a distribution strategy.
     dtype: The desired dtype of the dataset. This will be set during
@@ -194,6 +195,14 @@ class DatasetBuilder:
     """The global batch size across all replicas."""
     return self.batch_size
 
+  @property
+  def local_batch_size(self):
+    """The base unscaled batch size."""
+    if self.config.use_per_replica_batch_size:
+      return self.config.batch_size
+    else:
+      return self.config.batch_size // self.config.num_devices
+
   @property
   def num_steps(self) -> int:
     """The number of steps (batches) to exhaust this dataset."""
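The new `local_batch_size` property is the counterpart of `global_batch_size` and makes the two batch-size conventions explicit. A small worked sketch with hypothetical numbers (not taken from any config in this commit):

```python
# Hypothetical values mirroring the property above.
num_devices = 8

# use_per_replica_batch_size=True: config.batch_size is already per replica.
per_replica = 32
assert per_replica * num_devices == 256   # global batch consumed per training step

# use_per_replica_batch_size=False: config.batch_size is the global batch.
global_batch = 256
assert global_batch // num_devices == 32  # local_batch_size handed to each replica
```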
@@ -264,19 +273,42 @@ class DatasetBuilder:
       self.builder_info = tfds.builder(self.config.name).info
     return self.builder_info
 
-  def build(self, input_context: tf.distribute.InputContext = None
-           ) -> tf.data.Dataset:
+  def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
+    """Construct a dataset end-to-end and return it using an optional strategy.
+
+    Args:
+      strategy: a strategy that, if passed, will distribute the dataset
+        according to that strategy. If passed and `num_devices > 1`,
+        `use_per_replica_batch_size` must be set to `True`.
+
+    Returns:
+      A TensorFlow dataset outputting batched images and labels.
+    """
+    if strategy:
+      if strategy.num_replicas_in_sync != self.config.num_devices:
+        logging.warn('Passed a strategy with %d devices, but expected'
+                     '%d devices.',
+                     strategy.num_replicas_in_sync,
+                     self.config.num_devices)
+      dataset = strategy.experimental_distribute_datasets_from_function(
+          self._build)
+    else:
+      dataset = self._build()
+
+    return dataset
+
+  def _build(self, input_context: tf.distribute.InputContext = None
+            ) -> tf.data.Dataset:
     """Construct a dataset end-to-end and return it.
 
     Args:
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training.
 
     Returns:
       A TensorFlow dataset outputting batched images and labels.
     """
     builders = {
         'tfds': self.load_tfds,
         'records': self.load_records,
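`build()` is now a thin dispatcher: given a strategy it hands `_build` to `experimental_distribute_datasets_from_function`, which invokes `_build(input_context)` once per input pipeline; without a strategy it simply calls `_build()` itself. A hypothetical usage sketch follows; the `DatasetConfig`/`DatasetBuilder` constructor arguments are assumptions, since their signatures are not shown in this diff:

```python
strategy = tf.distribute.MirroredStrategy()

# Assumed construction. Keeping num_devices equal to strategy.num_replicas_in_sync
# matters: build() only logs a warning on a mismatch, but the batch-size scaling
# would then be wrong.
config = DatasetConfig(name='imagenet2012',
                       batch_size=32,
                       use_per_replica_batch_size=True,
                       num_devices=strategy.num_replicas_in_sync)
builder = DatasetBuilder(config)

dist_dataset = builder.build(strategy)  # distributed across the strategy's replicas
plain_dataset = builder.build()         # a regular tf.data.Dataset, single pipeline
```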
@@ -366,8 +398,8 @@ class DatasetBuilder:
     Args:
       dataset: A `tf.data.Dataset` that loads raw files.
       input_context: An optional context provided by `tf.distribute` for
-        cross-replica training. This isn't necessary if using Keras
-        compile/fit.
+        cross-replica training. If set with more than one replica, this
+        function assumes `use_per_replica_batch_size=True`.
 
     Returns:
       A TensorFlow dataset outputting batched images and labels.
@@ -387,8 +419,6 @@ class DatasetBuilder:
           cycle_length=16,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
-    dataset = dataset.prefetch(self.global_batch_size)
-
     if self.config.cache:
       dataset = dataset.cache()
@@ -404,13 +434,25 @@ class DatasetBuilder:
       dataset = dataset.map(preprocess,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
-    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)
-
     # Note: we could do image normalization here, but we defer it to the model
     # which can perform it much faster on a GPU/TPU
     # TODO(dankondratyuk): if we fix prefetching, we can do it here
+    if input_context and self.config.num_devices > 1:
+      if not self.config.use_per_replica_batch_size:
+        raise ValueError(
+            'The builder does not support a global batch size with more than '
+            'one replica. Got {} replicas. Please set a '
+            '`per_replica_batch_size` and enable '
+            '`use_per_replica_batch_size=True`.'.format(self.config.num_devices))
+
+      # The batch size of the dataset will be multiplied by the number of
+      # replicas automatically when strategy.distribute_datasets_from_function
+      # is called, so we use local batch size here.
+      dataset = dataset.batch(self.local_batch_size,
+                              drop_remainder=self.is_training)
+    else:
+      dataset = dataset.batch(self.global_batch_size,
+                              drop_remainder=self.is_training)
 
-    if self.is_training and self.config.deterministic_train is not None:
+    if self.is_training:
       options = tf.data.Options()
       options.experimental_deterministic = self.config.deterministic_train
       options.experimental_slack = self.config.use_slack
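This batching branch is the heart of the change: under `experimental_distribute_datasets_from_function`, every replica runs its own copy of `_build`, so each pipeline batches by `local_batch_size` and the global batch emerges from the replicas together. For completeness, here is a generic sketch of how a `dataset_fn` can use the `tf.distribute.InputContext` it receives; this is standard TensorFlow usage, not code from this repository, which relies on its own config fields instead:

```python
def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
  # Hypothetical global batch size; InputContext can derive the per-replica
  # size instead of tracking num_devices in a separate config.
  global_batch_size = 256
  per_replica_batch = input_context.get_per_replica_batch_size(global_batch_size)

  ds = tf.data.Dataset.range(100000)
  # Optionally shard the source so each input pipeline reads a disjoint slice.
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(per_replica_batch, drop_remainder=True)
```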
@@ -421,9 +463,7 @@ class DatasetBuilder:
       dataset = dataset.with_options(options)
 
     # Prefetch overlaps in-feed with training
-    # Note: autotune here is not recommended, as this can lead to memory leaks.
-    # Instead, use a constant prefetch size like the the number of devices.
-    dataset = dataset.prefetch(self.config.num_devices)
+    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
 
     return dataset