ModelZoo / ResNet50_tensorflow · Commits

Commit 7e9e15ad
Authored Aug 29, 2017 by Toby Boyd; committed via GitHub on Aug 29, 2017.
Parents: 3bf85a4e, 90fbe70e

Merge pull request #2056 from tfboyd/cifar_mkl

Added data_format flag to support MKL and other interesting tests
Changes: 2 changed files, with 109 additions and 128 deletions.

  tutorials/image/cifar10_estimator/cifar10.py        +2    -2
  tutorials/image/cifar10_estimator/cifar10_main.py   +107  -126
tutorials/image/cifar10_estimator/cifar10.py

@@ -74,8 +74,8 @@ class Cifar10DataSet(object):
     dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

     # Parse records.
-    dataset = dataset.map(self.parser, num_threads=batch_size,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(
+        self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)

     # Potentially shuffle records.
     if self.subset == 'train':
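For reference (not part of this commit): the map call above uses tf.contrib.data's num_threads and output_buffer_size arguments to parse records in parallel and keep a buffer of parsed examples ahead of the consumer. A minimal sketch of the same pipeline against the later core tf.data API, where those arguments became num_parallel_calls and prefetch():

    import tensorflow as tf

    def make_dataset(filenames, parser, batch_size):
      # Read the TFRecord files and repeat indefinitely.
      dataset = tf.data.TFRecordDataset(filenames).repeat()
      # Parse records in parallel; prefetch keeps roughly two batches of
      # parsed examples buffered, mirroring output_buffer_size=2 * batch_size.
      dataset = dataset.map(parser, num_parallel_calls=batch_size)
      return dataset.prefetch(2 * batch_size)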
tutorials/image/cifar10_estimator/cifar10_main.py

@@ -32,21 +32,21 @@ import argparse
 import functools
 import itertools
 import os

-import six
-import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
-import tensorflow as tf
-
 import cifar10
 import cifar10_model
 import cifar10_utils
+import numpy as np
+import six
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf

 tf.logging.set_verbosity(tf.logging.INFO)

-def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
+def get_model_fn(num_gpus, variable_strategy, num_workers):
+  """Returns a function that will build the resnet model."""

   def _resnet_model_fn(features, labels, mode, params):
     """Resnet model body.
@@ -74,6 +74,16 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
     tower_gradvars = []
     tower_preds = []

+    # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
+    # on CPU. The exception is Intel MKL on CPU which is optimal with
+    # channels_last.
+    data_format = params.data_format
+    if not data_format:
+      if num_gpus == 0:
+        data_format = 'channels_last'
+      else:
+        data_format = 'channels_first'
+
     if num_gpus == 0:
       num_devices = 1
       device_type = 'cpu'
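The added comment is the heart of this commit: NCHW (channels_first) is normally faster for convolutions on GPUs, while NHWC (channels_last) is faster on CPUs, including CPUs running Intel MKL. As a standalone illustration (hypothetical code, not from this diff), the two layouts are just transposes of the same CIFAR-10 batch:

    import tensorflow as tf

    images_nhwc = tf.zeros([128, 32, 32, 3])  # batch, height, width, channels
    images_nchw = tf.transpose(images_nhwc, [0, 3, 1, 2])  # batch, channels, h, w
    print(images_nchw.shape)  # (128, 3, 32, 32)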
@@ -91,21 +101,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
             ps_device_type='gpu',
             worker_device=worker_device,
             ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                 num_gpus,
                 tf.contrib.training.byte_size_load_fn))

       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
             loss, gradvars, preds = _tower_fn(
-                is_training,
-                weight_decay,
-                tower_features[i],
-                tower_labels[i],
-                (device_type == 'cpu'),
-                params.num_layers,
-                params.batch_norm_decay,
-                params.batch_norm_epsilon)
+                is_training, weight_decay, tower_features[i], tower_labels[i],
+                data_format, params.num_layers, params.batch_norm_decay,
+                params.batch_norm_epsilon)
             tower_losses.append(loss)
             tower_gradvars.append(gradvars)
@@ -136,7 +138,6 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
         avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
         gradvars.append((avg_grad, var))

     # Device that runs the ops to apply global gradient updates.
     consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
     with tf.device(consolidation_device):
@@ -159,10 +160,9 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
           learning_rate=learning_rate, momentum=momentum)

       chief_hooks = []
-      if sync:
+      if params.sync:
         optimizer = tf.train.SyncReplicasOptimizer(
             optimizer,
             replicas_to_aggregate=num_workers)
         sync_replicas_hook = optimizer.make_session_run_hook(True)
         chief_hooks.append(sync_replicas_hook)
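The `if sync:` to `if params.sync:` change works because main() forwards the remaining parsed CLI flags into tf.contrib.training.HParams, which the Estimator hands to the model_fn as params. A small sketch of that mechanism with made-up values:

    import tensorflow as tf

    # main() effectively does HParams(**hparams) over the leftover argparse flags.
    hparams = tf.contrib.training.HParams(sync=False, train_batch_size=128)
    print(hparams.sync)  # False; inside the model_fn this is params.sync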
@@ -182,7 +182,8 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
       }
       stacked_labels = tf.concat(labels, axis=0)
       metrics = {
           'accuracy':
               tf.metrics.accuracy(stacked_labels, predictions['classes'])
       }

     loss = tf.reduce_mean(tower_losses, name='loss')
@@ -193,35 +194,35 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
         train_op=train_op,
         training_chief_hooks=chief_hooks,
         eval_metric_ops=metrics)

   return _resnet_model_fn


-def _tower_fn(is_training,
-              weight_decay,
-              feature,
-              label,
-              is_cpu,
-              num_layers,
-              batch_norm_decay,
-              batch_norm_epsilon):
-  """Build computation tower for each device (CPU or GPU).
+def _tower_fn(is_training, weight_decay, feature, label, data_format,
+              num_layers, batch_norm_decay, batch_norm_epsilon):
+  """Build computation tower (Resnet).

   Args:
     is_training: true if is training graph.
     weight_decay: weight regularization strength, a float.
     feature: a Tensor.
     label: a Tensor.
-    tower_losses: a list to be appended with current tower's loss.
-    tower_gradvars: a list to be appended with current tower's gradients.
-    tower_preds: a list to be appended with current tower's predictions.
-    is_cpu: true if build tower on CPU.
+    data_format: channels_last (NHWC) or channels_first (NCHW).
+    num_layers: number of layers, an int.
+    batch_norm_decay: decay for batch normalization, a float.
+    batch_norm_epsilon: epsilon for batch normalization, a float.
+
+  Returns:
+    A tuple with the loss for the tower, the gradients and parameters, and
+    predictions.

   """
-  data_format = 'channels_last' if is_cpu else 'channels_first'
   model = cifar10_model.ResNetCifar10(
       num_layers,
       batch_norm_decay=batch_norm_decay,
       batch_norm_epsilon=batch_norm_epsilon,
       is_training=is_training,
       data_format=data_format)
   logits = model.forward_pass(feature, input_data_format='channels_last')
   tower_pred = {
       'classes': tf.argmax(input=logits, axis=1),
@@ -241,13 +242,20 @@ def _tower_fn(is_training,
   return tower_loss, zip(tower_grad, model_params), tower_pred


 def input_fn(data_dir, subset, num_shards, batch_size,
              use_distortion_for_training=True):
   """Create input graph for model.

   Args:
+    data_dir: Directory where TFRecords representing the dataset are located.
     subset: one of 'train', 'validate' and 'eval'.
     num_shards: num of towers participating in data-parallel training.
+    batch_size: total batch size for training to be divided by the number of
+        shards.
+    use_distortion_for_training: True to use distortions.

   Returns:
     two lists of tensors for features and labels, each of num_shards length.
   """
@@ -276,10 +284,10 @@ def input_fn(data_dir, subset, num_shards, batch_size,
   return feature_shards, label_shards


-# create experiment
-def get_experiment_fn(data_dir,
-                      num_gpus, is_gpu_ps,
-                      use_distortion_for_training=True,
-                      sync=True):
+def get_experiment_fn(data_dir,
+                      num_gpus,
+                      variable_strategy,
+                      use_distortion_for_training=True):
   """Returns an Experiment function.

   Experiments perform training on several workers in parallel,
@@ -291,9 +299,9 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
   Args:
     data_dir: str. Location of the data for input_fns.
     num_gpus: int. Number of GPUs on each worker.
-    is_gpu_ps: bool. If true, average gradients on GPUs.
+    variable_strategy: String. CPU to use CPU as the parameter server
+        and GPU to use the GPUs as the parameter server.
     use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
-    sync: bool. If true synchronizes variable updates across workers.

   Returns:
     A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
     tf.contrib.learn.Experiment.
@@ -302,6 +310,7 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     methods on Experiment (train, evaluate) based on information
     about the current runner in `run_config`.
   """

   def _experiment_fn(run_config, hparams):
     """Returns an Experiment."""
     # Create estimator.
@@ -311,28 +320,26 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
         subset='train',
         num_shards=num_gpus,
         batch_size=hparams.train_batch_size,
         use_distortion_for_training=use_distortion_for_training
     )
     eval_input_fn = functools.partial(
         input_fn,
         data_dir,
         subset='eval',
         batch_size=hparams.eval_batch_size,
         num_shards=num_gpus
     )

     num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
     if num_eval_examples % hparams.eval_batch_size != 0:
       raise ValueError('validation set size must be multiple of eval_batch_size')

     train_steps = hparams.train_steps
     eval_steps = num_eval_examples // hparams.eval_batch_size

     examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
         hparams.train_batch_size, every_n_steps=10)

     tensors_to_log = {'learning_rate': 'learning_rate',
                       'loss': 'loss'}

     logging_hook = tf.train.LoggingTensorHook(
         tensors=tensors_to_log, every_n_iter=100)
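To make the divisibility check above concrete (a worked example, not code from the diff): the CIFAR-10 eval split has 10,000 images, so the default --eval-batch-size=100 passes the check and evaluation runs for 100 steps:

    num_eval_examples = 10000  # Cifar10DataSet.num_examples_per_epoch('eval')
    eval_batch_size = 100      # default --eval-batch-size
    assert num_eval_examples % eval_batch_size == 0
    eval_steps = num_eval_examples // eval_batch_size  # 100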
@@ -340,11 +347,10 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     hooks = [logging_hook, examples_sec_hook]

     classifier = tf.estimator.Estimator(
-        model_fn=get_model_fn(num_gpus, is_gpu_ps,
-                              run_config.num_worker_replicas or 1, sync),
+        model_fn=get_model_fn(num_gpus, variable_strategy,
+                              run_config.num_worker_replicas or 1),
         config=run_config,
         params=hparams)

     # Create experiment.
     experiment = tf.contrib.learn.Experiment(
@@ -356,45 +362,31 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     # Adding hooks to be used by the estimator on training modes
     experiment.extend_train_hooks(hooks)
     return experiment

   return _experiment_fn


-def main(job_dir,
-         data_dir,
-         num_gpus,
-         variable_strategy,
-         use_distortion_for_training,
-         log_device_placement,
-         num_intra_threads,
-         sync,
-         **hparams):
+def main(job_dir, data_dir, num_gpus, variable_strategy,
+         use_distortion_for_training, log_device_placement, num_intra_threads,
+         **hparams):
   # The env variable is on deprecation path, default is set to off.
   os.environ['TF_SYNC_ON_FINISH'] = '0'
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

   # Session configuration.
   sess_config = tf.ConfigProto(
       allow_soft_placement=True,
       log_device_placement=log_device_placement,
       intra_op_parallelism_threads=num_intra_threads,
       gpu_options=tf.GPUOptions(
           force_gpu_compatible=True))

   config = cifar10_utils.RunConfig(
       session_config=sess_config, model_dir=job_dir)
   tf.contrib.learn.learn_runner.run(
       get_experiment_fn(
-          data_dir,
-          num_gpus,
-          variable_strategy,
-          use_distortion_for_training,
-          sync),
+          data_dir, num_gpus, variable_strategy,
+          use_distortion_for_training),
       run_config=config,
-      hparams=tf.contrib.training.HParams(**hparams)
-  )
+      hparams=tf.contrib.training.HParams(**hparams))


 if __name__ == '__main__':
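An easy-to-miss addition in this hunk is the TF_ENABLE_WINOGRAD_NONFUSED environment variable. In TensorFlow 1.x this was a known performance switch enabling the non-fused Winograd convolution kernels, which generally benchmarked faster for ResNet-style models on NVIDIA GPUs; it has to be set before the first session is created, hence its position at the top of main(). The relevant lines in isolation:

    import os

    os.environ['TF_SYNC_ON_FINISH'] = '0'            # deprecated flag, keep off
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'  # allow non-fused Winograd conv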
@@ -403,63 +395,53 @@ if __name__ == '__main__':
       '--data-dir',
       type=str,
       required=True,
       help='The directory where the CIFAR-10 input data is stored.')
   parser.add_argument(
       '--job-dir',
       type=str,
       required=True,
       help='The directory where the model will be stored.')
   parser.add_argument(
       '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
       help='Where to locate variable operations')
   parser.add_argument(
       '--num-gpus',
       type=int,
       default=1,
       help='The number of gpus used. Uses only CPU if set to 0.')
   parser.add_argument(
       '--num-layers',
       type=int,
       default=44,
       help='The number of layers of the model.')
   parser.add_argument(
       '--train-steps',
       type=int,
       default=80000,
       help='The number of steps to use for training.')
   parser.add_argument(
       '--train-batch-size',
       type=int,
       default=128,
       help='Batch size for training.')
   parser.add_argument(
       '--eval-batch-size',
       type=int,
       default=100,
       help='Batch size for validation.')
   parser.add_argument(
       '--momentum',
       type=float,
       default=0.9,
       help='Momentum for MomentumOptimizer.')
   parser.add_argument(
       '--weight-decay',
       type=float,
       default=2e-4,
       help='Weight decay for convolutions.')
   parser.add_argument(
       '--learning-rate',
       type=float,
@@ -468,22 +450,19 @@ if __name__ == '__main__':
      This is the inital learning rate value. The learning rate will decrease
      during training. For more details check the model_fn implementation in
      this file.\
      """)
   parser.add_argument(
       '--use-distortion-for-training',
       type=bool,
       default=True,
       help='If doing image distortion for training.')
   parser.add_argument(
       '--sync',
       action='store_true',
       default=False,
       help="""\
      If present when running in a distributed environment will run on sync mode.\
      """)
   parser.add_argument(
       '--num-intra-threads',
       type=int,
@@ -492,8 +471,7 @@ if __name__ == '__main__':
      Number of threads to use for intra-op parallelism. When training on CPU
      set to 0 to have the system pick the appropriate number or alternatively
      set it to the number of physical CPU cores.\
      """)
   parser.add_argument(
       '--num-inter-threads',
       type=int,
@@ -501,34 +479,37 @@ if __name__ == '__main__':
       help="""\
      Number of threads to use for inter-op parallelism. If set to 0, the
      system will pick an appropriate number.\
      """)
+  parser.add_argument(
+      '--data-format',
+      type=str,
+      default=None,
+      help="""\
+      If not set, the data format best for the training device is used.
+      Allowed values: channels_first (NCHW) channels_last (NHWC).\
+      """)
   parser.add_argument(
       '--log-device-placement',
       action='store_true',
       default=False,
       help='Whether to log device placement.')
   parser.add_argument(
       '--batch-norm-decay',
       type=float,
       default=0.997,
       help='Decay for batch norm.')
   parser.add_argument(
       '--batch-norm-epsilon',
       type=float,
       default=1e-5,
       help='Epsilon for batch norm.')
   args = parser.parse_args()

   if args.num_gpus < 0:
     raise ValueError(
         'Invalid GPU count: \"--num-gpus\" must be 0 or a positive integer.')
   if args.num_gpus == 0 and args.variable_strategy == 'GPU':
     raise ValueError('num-gpus=0, CPU must be used as parameter server. Set'
                      '--variable-strategy=CPU.')
   if (args.num_layers - 2) % 6 != 0:
     raise ValueError('Invalid --num-layers parameter.')
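A hypothetical invocation exercising the new flag (paths are placeholders; all other values fall back to the argparse defaults above):

    python cifar10_main.py --data-dir=/tmp/cifar-10-data \
                           --job-dir=/tmp/cifar10-model \
                           --num-gpus=0 \
                           --data-format=channels_last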