ModelZoo / ResNet50_tensorflow · Commits · add2845a

Commit add2845a, authored Aug 20, 2017 by Toby Boyd
Parent: a7531875

Style cleanup
Showing 2 changed files with 105 additions and 119 deletions:

  tutorials/image/cifar10_estimator/cifar10.py        +2    -2
  tutorials/image/cifar10_estimator/cifar10_main.py   +103  -117
tutorials/image/cifar10_estimator/cifar10.py
@@ -74,8 +74,8 @@ class Cifar10DataSet(object):
     dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
 
     # Parse records.
-    dataset = dataset.map(self.parser, num_threads=batch_size,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(
+        self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)
 
     # Potentially shuffle records.
     if self.subset == 'train':
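Note: the two sides differ only in argument wrapping; the pipeline itself is unchanged. For readers unfamiliar with the pre-tf.data API used here, a minimal sketch of the same pattern, assuming TensorFlow 1.3's tf.contrib.data (the function and file names are illustrative, not from this diff):

    import tensorflow as tf

    def make_pipeline(filenames, parser, batch_size):
      # Read serialized tf.Example records and loop over the data forever.
      dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
      # Parse records in parallel; the 2x output buffer keeps the consumer
      # fed while the next run of records is being parsed.
      dataset = dataset.map(
          parser, num_threads=batch_size, output_buffer_size=2 * batch_size)
      return dataset.batch(batch_size)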
tutorials/image/cifar10_estimator/cifar10_main.py
@@ -29,25 +29,23 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
-import collections
 import functools
 import itertools
 import os
-import six
 
-import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
-import tensorflow as tf
-
 import cifar10
 import cifar10_model
 import cifar10_utils
+import numpy as np
+import six
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
 
 tf.logging.set_verbosity(tf.logging.INFO)
 
 
-def get_model_fn(num_gpus, variable_strategy, num_workers):
+def get_model_fn(num_gpus, variable_strategy, data_format, num_workers):
   """Returns a function that will build the resnet model."""
 
   def _resnet_model_fn(features, labels, mode, params):
     """Resnet model body.
@@ -85,28 +83,20 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
     for i in range(num_devices):
       worker_device = '/{}:{}'.format(device_type, i)
       if variable_strategy == 'CPU':
         device_setter = cifar10_utils.local_device_setter(
             worker_device=worker_device)
       elif variable_strategy == 'GPU':
         device_setter = cifar10_utils.local_device_setter(
             ps_device_type='gpu',
             worker_device=worker_device,
             ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
-                num_gpus,
-                tf.contrib.training.byte_size_load_fn))
+                num_gpus, tf.contrib.training.byte_size_load_fn))
       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
-            loss, gradvars, preds = _tower_fn(
-                is_training,
-                weight_decay,
-                tower_features[i],
-                tower_labels[i],
-                (device_type == 'cpu'),
-                params['num_layers'],
-                params['batch_norm_decay'],
-                params['batch_norm_epsilon'])
+            loss, gradvars, preds = _tower_fn(
+                is_training, weight_decay, tower_features[i], tower_labels[i],
+                data_format, params['num_layers'], params['batch_norm_decay'],
+                params['batch_norm_epsilon'])
             tower_losses.append(loss)
             tower_gradvars.append(gradvars)
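Note: most of this hunk is argument reflow, but there is one behavioral change inside the _tower_fn call: the explicit data_format now replaces the (device_type == 'cpu') flag. The surrounding in-graph replication pattern, reduced to its skeleton (a hedged sketch; tower_fn and build_towers are placeholders, not names from the file):

    import tensorflow as tf

    def build_towers(tower_fn, num_devices, device_type='gpu'):
      # Variables live in one 'resnet' variable scope and are shared: every
      # tower after the first sets reuse=True. Op names stay distinct via the
      # per-tower name scope, and ops are pinned to the tower's device.
      outputs = []
      for i in range(num_devices):
        with tf.variable_scope('resnet', reuse=bool(i != 0)):
          with tf.name_scope('tower_%d' % i):
            with tf.device('/{}:{}'.format(device_type, i)):
              outputs.append(tower_fn(i))
      return outputs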
@@ -137,7 +127,6 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
         avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
         gradvars.append((avg_grad, var))
-
     # Device that runs the ops to apply global gradient updates.
     consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
     with tf.device(consolidation_device):
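Note: for context on the unchanged lines above, each tower contributes one gradient per variable and the consolidation device averages them before applying the update. A minimal sketch of that reduction, assuming every tower yields its (grad, var) pairs in the same order (the real code groups gradients by variable first):

    import tensorflow as tf

    def average_gradvars(tower_gradvars):
      # tower_gradvars: one list of (grad, var) pairs per tower.
      gradvars = []
      for pairs in zip(*tower_gradvars):
        grads = [g for g, _ in pairs]
        var = pairs[0][1]
        # Sum over towers, then scale by 1/num_towers.
        avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
        gradvars.append((avg_grad, var))
      return gradvars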
@@ -163,8 +152,7 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
 
       chief_hooks = []
       if params['sync']:
         optimizer = tf.train.SyncReplicasOptimizer(
-            optimizer,
-            replicas_to_aggregate=num_workers)
+            optimizer, replicas_to_aggregate=num_workers)
         sync_replicas_hook = optimizer.make_session_run_hook(True)
         chief_hooks.append(sync_replicas_hook)
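Note: the synchronous-training pattern this hunk reformats wraps the base optimizer so each step aggregates gradients from num_workers replicas; the chief needs the session hook to initialize the aggregation queues. A hedged sketch with illustrative values:

    import tensorflow as tf

    num_workers = 2  # illustrative
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer, replicas_to_aggregate=num_workers)
    # True marks this task as chief; the hook goes to the training session.
    sync_replicas_hook = optimizer.make_session_run_hook(True)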
@@ -184,7 +172,8 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
       }
       stacked_labels = tf.concat(labels, axis=0)
       metrics = {
-          'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
+          'accuracy':
+              tf.metrics.accuracy(stacked_labels, predictions['classes'])
       }
 
       loss = tf.reduce_mean(tower_losses, name='loss')
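Note: the reflow leaves semantics untouched. tf.metrics.accuracy returns an (accuracy, update_op) pair, which is exactly the form eval_metric_ops expects; a toy illustration with made-up tensors:

    import tensorflow as tf

    labels = tf.constant([1, 2, 4, 8])
    classes = tf.constant([1, 2, 4, 7])
    # Returns (value_tensor, update_op); the Estimator runs update_op on
    # every eval step and reports the final value_tensor.
    accuracy, update_op = tf.metrics.accuracy(labels, classes)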
@@ -195,35 +184,35 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
         train_op=train_op,
         training_chief_hooks=chief_hooks,
         eval_metric_ops=metrics)
 
   return _resnet_model_fn
 
 
-def _tower_fn(is_training,
-              weight_decay,
-              feature,
-              label,
-              is_cpu,
-              num_layers,
-              batch_norm_decay,
-              batch_norm_epsilon):
-  """Build computation tower for each device (CPU or GPU).
+def _tower_fn(is_training, weight_decay, feature, label, data_format,
+              num_layers, batch_norm_decay, batch_norm_epsilon):
+  """Build computation tower (Resnet).
 
   Args:
     is_training: true if is training graph.
     weight_decay: weight regularization strength, a float.
     feature: a Tensor.
     label: a Tensor.
-    tower_losses: a list to be appended with current tower's loss.
-    tower_gradvars: a list to be appended with current tower's gradients.
-    tower_preds: a list to be appended with current tower's predictions.
-    is_cpu: true if build tower on CPU.
+    data_format: channels_last (NHWC) or channels_first (NCHW).
+    num_layers: number of layers, an int.
+    batch_norm_decay: decay for batch normalization, a float.
+    batch_norm_epsilon: epsilon for batch normalization, a float.
+
+  Returns:
+    A tuple with the loss for the tower, the gradients and parameters, and
+    predictions.
   """
-  data_format = 'channels_last' if is_cpu else 'channels_first'
   model = cifar10_model.ResNetCifar10(
       num_layers,
       batch_norm_decay=batch_norm_decay,
       batch_norm_epsilon=batch_norm_epsilon,
       is_training=is_training,
       data_format=data_format)
   logits = model.forward_pass(feature, input_data_format='channels_last')
   tower_pred = {
       'classes': tf.argmax(input=logits, axis=1),
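Note: the docstring fix reflects a real interface change, not just style: callers now pass data_format directly instead of an is_cpu flag, and the channels_last-vs-channels_first choice moves up to main(). The input batch is still fed as NHWC (input_data_format='channels_last'); a hedged sketch of the conversion the model would perform internally when computing in NCHW (the actual logic lives in cifar10_model.ResNetCifar10, which is not part of this diff):

    import tensorflow as tf

    def to_compute_format(x, data_format):
      # x: NHWC image batch, e.g. shape [batch, 32, 32, 3] for CIFAR-10.
      if data_format == 'channels_first':
        return tf.transpose(x, [0, 3, 1, 2])  # NHWC -> NCHW
      return x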
@@ -243,13 +232,20 @@ def _tower_fn(is_training,
   return tower_loss, zip(tower_grad, model_params), tower_pred
 
 
-def input_fn(data_dir, subset, num_shards, batch_size,
+def input_fn(data_dir,
+             subset,
+             num_shards,
+             batch_size,
              use_distortion_for_training=True):
   """Create input graph for model.
 
   Args:
+    data_dir: Directory where TFRecords representing the dataset are located.
     subset: one of 'train', 'validate' and 'eval'.
     num_shards: num of towers participating in data-parallel training.
+    batch_size: total batch size for training to be divided by the number of
+        shards.
+    use_distortion_for_training: True to use distortions.
 
   Returns:
     two lists of tensors for features and labels, each of num_shards length.
   """
@@ -279,7 +275,10 @@ def input_fn(data_dir, subset, num_shards, batch_size,
 
 
 # create experiment
-def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
+def get_experiment_fn(data_dir,
+                      num_gpus,
+                      variable_strategy,
+                      data_format,
                       use_distortion_for_training=True):
   """Returns an Experiment function.
@@ -292,7 +291,9 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
   Args:
     data_dir: str. Location of the data for input_fns.
     num_gpus: int. Number of GPUs on each worker.
-    is_gpu_ps: bool. If true, average gradients on GPUs.
+    variable_strategy: String. CPU to use CPU as the parameter server
+        and GPU to use the GPUs as the parameter server.
+    data_format: String. channels_first or channels_last.
     use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
 
   Returns:
     A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
@@ -302,6 +303,7 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
       methods on Experiment (train, evaluate) based on information
       about the current runner in `run_config`.
   """
+
   def _experiment_fn(run_config, hparams):
     """Returns an Experiment."""
     # Create estimator.
@@ -311,40 +313,37 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
         subset='train',
         num_shards=num_gpus,
         batch_size=hparams.train_batch_size,
-        use_distortion_for_training=use_distortion_for_training
-    )
+        use_distortion_for_training=use_distortion_for_training)
 
     eval_input_fn = functools.partial(
         input_fn,
         data_dir,
         subset='eval',
         batch_size=hparams.eval_batch_size,
-        num_shards=num_gpus
-    )
+        num_shards=num_gpus)
 
     num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
     if num_eval_examples % hparams.eval_batch_size != 0:
       raise ValueError(
          'validation set size must be multiple of eval_batch_size')
 
     train_steps = hparams.train_steps
     eval_steps = num_eval_examples // hparams.eval_batch_size
 
     examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
         hparams.train_batch_size, every_n_steps=10)
 
     tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
 
     logging_hook = tf.train.LoggingTensorHook(
         tensors=tensors_to_log, every_n_iter=100)
     hooks = [logging_hook, examples_sec_hook]
 
     classifier = tf.estimator.Estimator(
-        model_fn=get_model_fn(num_gpus, is_gpu_ps,
+        model_fn=get_model_fn(num_gpus, variable_strategy, data_format,
                               run_config.num_worker_replicas or 1),
         config=run_config,
-        params=vars(hparams)
-    )
+        params=vars(hparams))
 
     # Create experiment.
     experiment = tf.contrib.learn.Experiment(
@@ -356,43 +355,40 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     # Adding hooks to be used by the estimator on training modes
     experiment.extend_train_hooks(hooks)
     return experiment
 
   return _experiment_fn
 
 
-def main(job_dir,
-         data_dir,
-         num_gpus,
-         variable_strategy,
-         use_distortion_for_training,
-         log_device_placement,
-         num_intra_threads,
+def main(job_dir, data_dir, num_gpus, variable_strategy, data_format,
+         use_distortion_for_training, log_device_placement, num_intra_threads,
          **hparams):
   # The env variable is on deprecation path, default is set to off.
   os.environ['TF_SYNC_ON_FINISH'] = '0'
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+
+  # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
+  # on CPU. The exception is Intel MKL on CPU which is optimal with
+  # channels_last.
+  if not data_format:
+    if num_gpus == 0:
+      data_format = 'channels_last'
+    else:
+      data_format = 'channels_first'
 
   # Session configuration.
   sess_config = tf.ConfigProto(
       allow_soft_placement=True,
       log_device_placement=log_device_placement,
       intra_op_parallelism_threads=num_intra_threads,
-      gpu_options=tf.GPUOptions(
-          force_gpu_compatible=True)
-  )
+      gpu_options=tf.GPUOptions(force_gpu_compatible=True))
 
   config = cifar10_utils.RunConfig(
       session_config=sess_config,
       model_dir=job_dir)
   tf.contrib.learn.learn_runner.run(
-      get_experiment_fn(
-          data_dir,
-          num_gpus,
-          variable_strategy,
-          use_distortion_for_training),
+      get_experiment_fn(data_dir, num_gpus, variable_strategy, data_format,
+                        use_distortion_for_training),
       run_config=config,
-      hparams=tf.contrib.training.HParams(**hparams)
-  )
+      hparams=tf.contrib.training.HParams(**hparams))
 
 
 if __name__ == '__main__':
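Note: the new default logic picks channels_first whenever GPUs are present and channels_last otherwise, matching the comment. main() now takes data_format as a named parameter while the remaining flags arrive through **hparams; a hedged sketch of the dispatch this implies (the actual call line sits outside the visible hunks, so this is an assumption about the surrounding file):

    args = parser.parse_args()
    # Named parameters of main() are consumed by name; everything else lands
    # in **hparams and is forwarded to tf.contrib.training.HParams.
    main(**vars(args))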
@@ -401,63 +397,53 @@ if __name__ == '__main__':
       '--data-dir',
       type=str,
       required=True,
-      help='The directory where the CIFAR-10 input data is stored.'
-  )
+      help='The directory where the CIFAR-10 input data is stored.')
   parser.add_argument(
       '--job-dir',
       type=str,
       required=True,
-      help='The directory where the model will be stored.'
-  )
+      help='The directory where the model will be stored.')
   parser.add_argument(
       '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
-      help='Where to locate variable operations'
-  )
+      help='Where to locate variable operations')
   parser.add_argument(
       '--num-gpus',
       type=int,
       default=1,
-      help='The number of gpus used. Uses only CPU if set to 0.'
-  )
+      help='The number of gpus used. Uses only CPU if set to 0.')
   parser.add_argument(
       '--num-layers',
       type=int,
       default=44,
-      help='The number of layers of the model.'
-  )
+      help='The number of layers of the model.')
   parser.add_argument(
       '--train-steps',
       type=int,
       default=80000,
-      help='The number of steps to use for training.'
-  )
+      help='The number of steps to use for training.')
   parser.add_argument(
       '--train-batch-size',
       type=int,
       default=128,
-      help='Batch size for training.'
-  )
+      help='Batch size for training.')
   parser.add_argument(
       '--eval-batch-size',
       type=int,
       default=100,
-      help='Batch size for validation.'
-  )
+      help='Batch size for validation.')
   parser.add_argument(
       '--momentum',
       type=float,
       default=0.9,
-      help='Momentum for MomentumOptimizer.'
-  )
+      help='Momentum for MomentumOptimizer.')
   parser.add_argument(
       '--weight-decay',
       type=float,
       default=2e-4,
-      help='Weight decay for convolutions.'
-  )
+      help='Weight decay for convolutions.')
   parser.add_argument(
       '--learning-rate',
       type=float,
@@ -466,22 +452,19 @@ if __name__ == '__main__':
       This is the inital learning rate value. The learning rate will decrease
       during training. For more details check the model_fn implementation in
       this file.\
-      """
-  )
+      """)
   parser.add_argument(
       '--use-distortion-for-training',
       type=bool,
       default=True,
-      help='If doing image distortion for training.'
-  )
+      help='If doing image distortion for training.')
   parser.add_argument(
       '--sync',
       action='store_true',
       default=False,
       help="""\
      If present when running in a distributed environment will run on sync mode.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-intra-threads',
       type=int,
@@ -492,8 +475,7 @@ if __name__ == '__main__':
      example CPU only handles the input pipeline and gradient aggregation
      (when --is-cpu-ps). Ops that could potentially benefit from intra-op
      parallelism are scheduled to run on GPUs.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-inter-threads',
       type=int,
@@ -501,26 +483,30 @@ if __name__ == '__main__':
       help="""\
      Number of threads to use for inter-op parallelism. If set to 0, the
      system will pick an appropriate number.\
-      """
-  )
+      """)
+  parser.add_argument(
+      '--data-format',
+      type=str,
+      default=None,
+      help="""\
+      If not set, the data format best for the training device is used.
+      Allowed values: channels_first (NCHW) channels_last (NHWC).\
+      """)
   parser.add_argument(
       '--log-device-placement',
       action='store_true',
       default=False,
-      help='Whether to log device placement.'
-  )
+      help='Whether to log device placement.')
   parser.add_argument(
       '--batch-norm-decay',
       type=float,
       default=0.997,
-      help='Decay for batch norm.'
-  )
+      help='Decay for batch norm.')
   parser.add_argument(
       '--batch-norm-epsilon',
       type=float,
       default=1e-5,
-      help='Epsilon for batch norm.'
-  )
+      help='Epsilon for batch norm.')
   args = parser.parse_args()
 
   if args.num_gpus < 0:
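Note: putting the flag changes together, a representative invocation after this commit (paths and sizes are illustrative; --data-format may be omitted to let the script pick the per-device default):

    python cifar10_main.py --data-dir=/tmp/cifar-10-data --job-dir=/tmp/cifar10 \
        --num-gpus=2 --variable-strategy=GPU --data-format=channels_first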