Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3b158095
Commit
3b158095
authored
May 07, 2018
by
Ilya Mironov
Browse files
Merge branch 'master' of
https://github.com/ilyamironov/models
parents
a90db800
be659c2f
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
693 additions
and
112 deletions
+693
-112
official/utils/flags/_base.py
official/utils/flags/_base.py
+140
-0
official/utils/flags/_benchmark.py
official/utils/flags/_benchmark.py
+67
-0
official/utils/flags/_conventions.py
official/utils/flags/_conventions.py
+36
-0
official/utils/flags/_misc.py
official/utils/flags/_misc.py
+50
-0
official/utils/flags/_performance.py
official/utils/flags/_performance.py
+132
-0
official/utils/flags/core.py
official/utils/flags/core.py
+84
-0
official/utils/flags/flags_test.py
official/utils/flags/flags_test.py
+31
-37
official/utils/logs/benchmark_uploader.py
official/utils/logs/benchmark_uploader.py
+15
-11
official/utils/logs/logger.py
official/utils/logs/logger.py
+45
-18
official/utils/logs/logger_test.py
official/utils/logs/logger_test.py
+26
-0
official/utils/testing/integration.py
official/utils/testing/integration.py
+6
-2
official/utils/testing/pylint.rcfile
official/utils/testing/pylint.rcfile
+1
-1
official/wide_deep/wide_deep.py
official/wide_deep/wide_deep.py
+45
-36
official/wide_deep/wide_deep_test.py
official/wide_deep/wide_deep_test.py
+5
-0
research/README.md
research/README.md
+2
-1
research/differential_privacy/__init__.py
research/differential_privacy/__init__.py
+1
-0
research/differential_privacy/multiple_teachers/__init__.py
research/differential_privacy/multiple_teachers/__init__.py
+1
-0
research/differential_privacy/multiple_teachers/analysis.py
research/differential_privacy/multiple_teachers/analysis.py
+1
-1
research/differential_privacy/pate/ICLR2018/plot_partition.py
...arch/differential_privacy/pate/ICLR2018/plot_partition.py
+1
-1
research/differential_privacy/pate/README.md
research/differential_privacy/pate/README.md
+4
-4
No files found.
official/utils/flags/_base.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
import
tensorflow
as
tf
from
official.utils.flags._conventions
import
help_wrap
from
official.utils.logs
import
hooks_helper
def define_base(data_dir=True, model_dir=True, train_epochs=True,
                epochs_between_evals=True, stop_threshold=True,
                batch_size=True, multi_gpu=False, num_gpu=True, hooks=True,
                export_dir=True):
  """Register base flags.

  Each boolean argument toggles the definition of the corresponding flag(s).

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other
      eval metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    multi_gpu: Create a flag to allow the use of all available GPUs.
    num_gpu: Create a flag to specify the number of GPUs used.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir", short_name="dd", default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir", short_name="md", default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs", short_name="te", default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals", short_name="ebe", default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  # NOTE: stop_threshold is deliberately not appended to key_flags here.
  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold", short_name="st", default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size", short_name="bs", default=32,
        help=help_wrap("Batch size for training and evaluation."))
    key_flags.append("batch_size")

  # multi_gpu and num_gpu are mutually exclusive ways of requesting GPU use.
  assert not (multi_gpu and num_gpu)

  if multi_gpu:
    flags.DEFINE_bool(
        name="multi_gpu", default=False,
        help=help_wrap("If set, run across all available GPUs."))
    key_flags.append("multi_gpu")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus", short_name="ng",
        default=1 if tf.test.is_gpu_available() else 0,
        help=help_wrap(
            "How many GPUs to use with the DistributionStrategies API. The "
            "default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))

  if hooks:
    # Construct a pretty summary of the available hooks. The zero-width
    # no-break space (\ufeff) protects the layout from absl's help reflowing.
    hook_list_str = (
        u"\ufeffHook:\n" + u"\n".join(
            [u"\ufeff{}".format(key) for key in hooks_helper.HOOKS]))
    flags.DEFINE_list(
        name="hooks", short_name="hk", default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks.\n{}\n\ufeffExample: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\nSee official.utils.logs.hooks_helper "
            u"for details.".format(hook_list_str))
    )
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir", short_name="ed", default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links.")
    )
    key_flags.append("export_dir")

  return key_flags
def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'."""
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  # Imported lazily so the common (explicit num_gpus) path never touches it.
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  devices = device_lib.list_local_devices()
  return sum(1 for dev in devices if dev.device_type == "GPU")
official/utils/flags/_benchmark.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for benchmarking models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
from
official.utils.flags._conventions
import
help_wrap
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
  """Register benchmarking flags.

  Args:
    benchmark_log_dir: Create a flag to specify location for benchmark logging.
    bigquery_uploader: Create flags for uploading results to BigQuery.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if benchmark_log_dir:
    flags.DEFINE_string(
        name="benchmark_log_dir", short_name="bld", default=None,
        help=help_wrap("The location of the benchmark logging.")
    )

  if bigquery_uploader:
    flags.DEFINE_string(
        name="gcp_project", short_name="gp", default=None,
        help=help_wrap(
            "The GCP project name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_data_set", short_name="bds", default="test_benchmark",
        help=help_wrap(
            "The Bigquery dataset name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_table", short_name="brt", default="benchmark_run",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_metric_table", short_name="bmt",
        default="benchmark_metric",
        help=help_wrap("The Bigquery table name where the benchmark metric "
                       "information will be uploaded."))

  return key_flags
official/utils/flags/_conventions.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for shared arparse convention definitions."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
from
absl
import
app
as
absl_app
from
absl
import
flags
# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help_wrap when calling DEFINE_*.
help_wrap = functools.partial(
    flags.text_wrap, length=80, indent="", firstline_indent="\n")

# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
official/utils/flags/_misc.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc flags."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
from
official.utils.flags._conventions
import
help_wrap
def define_image(data_format=True):
  """Register image specific flags.

  Args:
    data_format: Create a flag to specify image axis convention.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if data_format:
    flags.DEFINE_enum(
        name="data_format", short_name="df", default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    key_flags.append("data_format")

  return key_flags
official/utils/flags/_performance.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
multiprocessing
from
absl
import
flags
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.utils.flags._conventions
import
help_wrap
# Map a user-facing dtype string to (TensorFlow dtype, default loss scale).
# fp16 defaults to a loss scale of 128 to keep small gradients from
# underflowing; fp32 needs no scaling, so its default is 1.
DTYPE_MAP = {
    "fp16": (tf.float16, 128),
    "fp32": (tf.float32, 1),
}
def get_tf_dtype(flags_obj):
  """Return the TensorFlow dtype corresponding to flags_obj.dtype."""
  dtype, _ = DTYPE_MAP[flags_obj.dtype]
  return dtype
def get_loss_scale(flags_obj):
  """Return the explicit loss scale if set, else the dtype's default."""
  explicit_scale = flags_obj.loss_scale
  if explicit_scale is not None:
    return explicit_scale
  _, default_scale = DTYPE_MAP[flags_obj.dtype]
  return default_scale
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                       synthetic_data=True, max_train_steps=True, dtype=True):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of maximum number
      of training steps.
    dtype: Create flags for specifying dtype.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details.")
    )

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    flags.DEFINE_integer(
        name="loss_scale", short_name="ls", default=None,
        help=help_wrap(
            "The amount to scale the loss by when the model is run. Before "
            "gradients are computed, the loss is multiplied by the loss scale, "
            "making all gradients loss_scale times larger. To adjust for this, "
            "gradients are divided by the loss scale before being applied to "
            "variables. This is mathematically equivalent to training without "
            "a loss scale, but the loss scale helps avoid some intermediate "
            "gradients from underflowing to zero. If not provided the default "
            "for fp16 is 128 and 1 for all other dtypes."))

    loss_scale_val_msg = "loss_scale should be a positive integer."

    @flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
    def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
      if loss_scale is None:
        return True  # null case is handled in get_loss_scale()
      return loss_scale > 0

  return key_flags
official/utils/flags/core.py
0 → 100644
View file @
3b158095
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public interface for flag definition.
See _example.py for detailed instructions on defining flags.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
functools
import
sys
from
absl
import
app
as
absl_app
from
absl
import
flags
from
official.utils.flags
import
_base
from
official.utils.flags
import
_benchmark
from
official.utils.flags
import
_conventions
from
official.utils.flags
import
_misc
from
official.utils.flags
import
_performance
def set_defaults(**kwargs):
  """Override the default value of each named flag in **kwargs."""
  for flag_name, flag_value in kwargs.items():
    flags.FLAGS.set_default(name=flag_name, value=flag_value)
def parse_flags(argv=None):
  """Reset flags and reparse. Currently only used in testing."""
  flags.FLAGS.unparse_flags()
  # Fall back to the real command line when no argv is supplied.
  absl_app.parse_flags_with_usage(argv or sys.argv)
def register_key_flags_in_core(f):
  """Defines a function in core.py, and registers its key flags.

  absl uses the location of a flags.declare_key_flag() to determine the context
  in which a flag is key. By making all declares in core, this allows model
  main functions to call flags.adopt_module_key_flags() on core and correctly
  chain key flags.

  Args:
    f: The function to be wrapped. It must return a list of flag names.

  Returns:
    The "core-defined" version of the input function.
  """
  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    # A plain loop (rather than a side-effecting list comprehension) makes the
    # intent explicit and avoids the expression-not-assigned pylint suppression.
    for fl in key_flags:
      flags.declare_key_flag(fl)
  return core_fn
# Wrap each private flag-definition helper so that its key flags are declared
# in core's module context (see register_key_flags_in_core).
define_base = register_key_flags_in_core(_base.define_base)

# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(
    functools.partial(
        _base.define_base,
        epochs_between_evals=False,
        stop_threshold=False,
        multi_gpu=False,
        hooks=False))

define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)

# Re-export shared helpers so callers only need to import core.
help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
official/utils/
arg_parsers/parser
s_test.py
→
official/utils/
flags/flag
s_test.py
View file @
3b158095
# Copyright 201
7
The TensorFlow Authors. All Rights Reserved.
# Copyright 201
8
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -13,29 +13,28 @@
...
@@ -13,29 +13,28 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
argparse
import
unittest
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
flags
import
tensorflow
as
tf
from
official.utils.
arg_parsers
import
pars
er
s
from
official.utils.
flags
import
core
as
flags_core
# pylint: disable=g-bad-
import
-ord
er
class
TestParser
(
argparse
.
ArgumentParser
):
def
define_flags
():
"""Class to test canned parser functionality."""
flags_core
.
define_base
(
multi_gpu
=
True
,
num_gpu
=
False
)
flags_core
.
define_performance
()
def
__init__
(
self
):
flags_core
.
define_image
()
super
(
TestParser
,
self
).
__init__
(
parents
=
[
flags_core
.
define_benchmark
()
parsers
.
BaseParser
(),
parsers
.
PerformanceParser
(
num_parallel_calls
=
True
,
inter_op
=
True
,
intra_op
=
True
,
use_synthetic_data
=
True
),
parsers
.
ImageModelParser
(
data_format
=
True
),
parsers
.
BenchmarkParser
(
benchmark_log_dir
=
True
,
bigquery_uploader
=
True
)
])
class
BaseTester
(
unittest
.
TestCase
):
class
BaseTester
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
super
(
BaseTester
,
cls
).
setUpClass
()
define_flags
()
def
test_default_setting
(
self
):
def
test_default_setting
(
self
):
"""Test to ensure fields exist and defaults can be set.
"""Test to ensure fields exist and defaults can be set.
"""
"""
...
@@ -49,16 +48,15 @@ class BaseTester(unittest.TestCase):
...
@@ -49,16 +48,15 @@ class BaseTester(unittest.TestCase):
hooks
=
[
"LoggingTensorHook"
],
hooks
=
[
"LoggingTensorHook"
],
num_parallel_calls
=
18
,
num_parallel_calls
=
18
,
inter_op_parallelism_threads
=
5
,
inter_op_parallelism_threads
=
5
,
intra_op_parallelism_thread
=
10
,
intra_op_parallelism_thread
s
=
10
,
data_format
=
"channels_first"
data_format
=
"channels_first"
)
)
parser
=
TestParser
(
)
flags_core
.
set_defaults
(
**
defaults
)
parser
.
set_defaults
(
**
defaults
)
flags_core
.
parse_flags
(
)
namespace_vars
=
vars
(
parser
.
parse_args
([]))
for
key
,
value
in
defaults
.
items
():
for
key
,
value
in
defaults
.
items
():
assert
namespace_vars
[
key
]
==
value
assert
flags
.
FLAGS
.
get_flag_value
(
name
=
key
,
default
=
None
)
==
value
def
test_benchmark_setting
(
self
):
def
test_benchmark_setting
(
self
):
defaults
=
dict
(
defaults
=
dict
(
...
@@ -67,40 +65,36 @@ class BaseTester(unittest.TestCase):
...
@@ -67,40 +65,36 @@ class BaseTester(unittest.TestCase):
gcp_project
=
"project_abc"
,
gcp_project
=
"project_abc"
,
)
)
parser
=
TestParser
(
)
flags_core
.
set_defaults
(
**
defaults
)
parser
.
set_defaults
(
**
defaults
)
flags_core
.
parse_flags
(
)
namespace_vars
=
vars
(
parser
.
parse_args
([]))
for
key
,
value
in
defaults
.
items
():
for
key
,
value
in
defaults
.
items
():
assert
namespace_vars
[
key
]
==
value
assert
flags
.
FLAGS
.
get_flag_value
(
name
=
key
,
default
=
None
)
==
value
def
test_booleans
(
self
):
def
test_booleans
(
self
):
"""Test to ensure boolean flags trigger as expected.
"""Test to ensure boolean flags trigger as expected.
"""
"""
parser
=
TestParser
()
flags_core
.
parse_flags
([
__file__
,
"--multi_gpu"
,
"--use_synthetic_data"
])
namespace
=
parser
.
parse_args
([
"--multi_gpu"
,
"--use_synthetic_data"
])
assert
namespace
.
multi_gpu
assert
flags
.
FLAGS
.
multi_gpu
assert
namespace
.
use_synthetic_data
assert
flags
.
FLAGS
.
use_synthetic_data
def
test_parse_dtype_info
(
self
):
def
test_parse_dtype_info
(
self
):
parser
=
TestParser
()
for
dtype_str
,
tf_dtype
,
loss_scale
in
[[
"fp16"
,
tf
.
float16
,
128
],
for
dtype_str
,
tf_dtype
,
loss_scale
in
[[
"fp16"
,
tf
.
float16
,
128
],
[
"fp32"
,
tf
.
float32
,
1
]]:
[
"fp32"
,
tf
.
float32
,
1
]]:
args
=
parser
.
parse_args
([
"--dtype"
,
dtype_str
])
flags_core
.
parse_flags
([
__file__
,
"--dtype"
,
dtype_str
])
parsers
.
parse_dtype_info
(
args
)
assert
args
.
dtype
==
tf_dtype
self
.
assert
Equal
(
flags_core
.
get_tf_dtype
(
flags
.
FLAGS
),
tf_dtype
)
assert
args
.
loss_scale
==
loss_scale
self
.
assert
Equal
(
flags_core
.
get_loss_scale
(
flags
.
FLAGS
),
loss_scale
)
args
=
parser
.
parse_args
([
"--dtype"
,
dtype_str
,
"--loss_scale"
,
"5"
])
flags_core
.
parse_flags
(
parsers
.
parse_dtype_info
(
args
)
[
__file__
,
"--dtype"
,
dtype_str
,
"--loss_scale"
,
"5"
]
)
assert
args
.
loss_scale
==
5
self
.
assert
Equal
(
flags_core
.
get_loss_scale
(
flags
.
FLAGS
),
5
)
with
self
.
assertRaises
(
SystemExit
):
with
self
.
assertRaises
(
SystemExit
):
parser
.
parse_a
r
gs
([
"--dtype"
,
"int8"
])
flags_core
.
parse_
fl
ags
([
__file__
,
"--dtype"
,
"int8"
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
official/utils/logs/benchmark_uploader.py
View file @
3b158095
...
@@ -31,9 +31,13 @@ import uuid
...
@@ -31,9 +31,13 @@ import uuid
from
google.cloud
import
bigquery
from
google.cloud
import
bigquery
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
# pylint: disable=g-bad-import-order
from
absl
import
app
as
absl_app
from
absl
import
flags
import
tensorflow
as
tf
# pylint: enable=g-bad-import-order
from
official.utils.
arg_parsers
import
parsers
from
official.utils.
flags
import
core
as
flags_core
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
...
@@ -108,22 +112,22 @@ class BigQueryUploader(object):
...
@@ -108,22 +112,22 @@ class BigQueryUploader(object):
"Failed to upload benchmark info to bigquery: {}"
.
format
(
errors
))
"Failed to upload benchmark info to bigquery: {}"
.
format
(
errors
))
def
main
(
argv
):
def
main
(
_
):
parser
=
parsers
.
BenchmarkParser
()
if
not
flags
.
FLAGS
.
benchmark_log_dir
:
flags
=
parser
.
parse_args
(
args
=
argv
[
1
:])
if
not
flags
.
benchmark_log_dir
:
print
(
"Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir"
)
print
(
"Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir"
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
uploader
=
BigQueryUploader
(
uploader
=
BigQueryUploader
(
flags
.
benchmark_log_dir
,
flags
.
FLAGS
.
benchmark_log_dir
,
gcp_project
=
flags
.
gcp_project
)
gcp_project
=
flags
.
FLAGS
.
gcp_project
)
run_id
=
str
(
uuid
.
uuid4
())
run_id
=
str
(
uuid
.
uuid4
())
uploader
.
upload_benchmark_run
(
uploader
.
upload_benchmark_run
(
flags
.
bigquery_data_set
,
flags
.
bigquery_run_table
,
run_id
)
flags
.
FLAGS
.
bigquery_data_set
,
flags
.
FLAGS
.
bigquery_run_table
,
run_id
)
uploader
.
upload_metric
(
uploader
.
upload_metric
(
flags
.
bigquery_data_set
,
flags
.
bigquery_metric_table
,
run_id
)
flags
.
FLAGS
.
bigquery_data_set
,
flags
.
FLAGS
.
bigquery_metric_table
,
run_id
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
(
argv
=
sys
.
argv
)
flags_core
.
define_benchmark
()
flags
.
adopt_module_key_flags
(
flags_core
)
absl_app
.
run
(
main
=
main
)
official/utils/logs/logger.py
View file @
3b158095
...
@@ -109,8 +109,9 @@ class BaseBenchmarkLogger(object):
...
@@ -109,8 +109,9 @@ class BaseBenchmarkLogger(object):
"Name %s, value %d, unit %s, global_step %d, extras %s"
,
"Name %s, value %d, unit %s, global_step %d, extras %s"
,
name
,
value
,
unit
,
global_step
,
extras
)
name
,
value
,
unit
,
global_step
,
extras
)
def
log_run_info
(
self
,
model_name
):
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
):
tf
.
logging
.
info
(
"Benchmark run: %s"
,
_gather_run_info
(
model_name
))
tf
.
logging
.
info
(
"Benchmark run: %s"
,
_gather_run_info
(
model_name
,
dataset_name
,
run_params
))
class
BenchmarkFileLogger
(
BaseBenchmarkLogger
):
class
BenchmarkFileLogger
(
BaseBenchmarkLogger
):
...
@@ -159,15 +160,18 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
...
@@ -159,15 +160,18 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
tf
.
logging
.
warning
(
"Failed to dump metric to log file: "
tf
.
logging
.
warning
(
"Failed to dump metric to log file: "
"name %s, value %s, error %s"
,
name
,
value
,
e
)
"name %s, value %s, error %s"
,
name
,
value
,
e
)
def
log_run_info
(
self
,
model_name
):
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
):
"""Collect most of the TF runtime information for the local env.
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
Args:
model_name: string, the name of the model.
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
"""
"""
run_info
=
_gather_run_info
(
model_name
)
run_info
=
_gather_run_info
(
model_name
,
dataset_name
,
run_params
)
with
tf
.
gfile
.
GFile
(
os
.
path
.
join
(
with
tf
.
gfile
.
GFile
(
os
.
path
.
join
(
self
.
_logging_dir
,
BENCHMARK_RUN_LOG_FILE_NAME
),
"w"
)
as
f
:
self
.
_logging_dir
,
BENCHMARK_RUN_LOG_FILE_NAME
),
"w"
)
as
f
:
...
@@ -179,15 +183,17 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
...
@@ -179,15 +183,17 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
e
)
e
)
def
_gather_run_info
(
model_name
):
def
_gather_run_info
(
model_name
,
dataset_name
,
run_params
):
"""Collect the benchmark run information for the local environment."""
"""Collect the benchmark run information for the local environment."""
run_info
=
{
run_info
=
{
"model_name"
:
model_name
,
"model_name"
:
model_name
,
"dataset"
:
{
"name"
:
dataset_name
},
"machine_config"
:
{},
"machine_config"
:
{},
"run_date"
:
datetime
.
datetime
.
utcnow
().
strftime
(
"run_date"
:
datetime
.
datetime
.
utcnow
().
strftime
(
_DATE_TIME_FORMAT_PATTERN
)}
_DATE_TIME_FORMAT_PATTERN
)}
_collect_tensorflow_info
(
run_info
)
_collect_tensorflow_info
(
run_info
)
_collect_tensorflow_environment_variables
(
run_info
)
_collect_tensorflow_environment_variables
(
run_info
)
_collect_run_params
(
run_info
,
run_params
)
_collect_cpu_info
(
run_info
)
_collect_cpu_info
(
run_info
)
_collect_gpu_info
(
run_info
)
_collect_gpu_info
(
run_info
)
_collect_memory_info
(
run_info
)
_collect_memory_info
(
run_info
)
...
@@ -199,6 +205,21 @@ def _collect_tensorflow_info(run_info):
...
@@ -199,6 +205,21 @@ def _collect_tensorflow_info(run_info):
"version"
:
tf
.
VERSION
,
"git_hash"
:
tf
.
GIT_VERSION
}
"version"
:
tf
.
VERSION
,
"git_hash"
:
tf
.
GIT_VERSION
}
def
_collect_run_params
(
run_info
,
run_params
):
"""Log the parameter information for the benchmark run."""
def
process_param
(
name
,
value
):
type_check
=
{
str
:
{
"name"
:
name
,
"string_value"
:
value
},
int
:
{
"name"
:
name
,
"long_value"
:
value
},
bool
:
{
"name"
:
name
,
"bool_value"
:
str
(
value
)},
float
:
{
"name"
:
name
,
"float_value"
:
value
},
}
return
type_check
.
get
(
type
(
value
),
{
"name"
:
name
,
"string_value"
:
str
(
value
)})
if
run_params
:
run_info
[
"run_parameters"
]
=
[
process_param
(
k
,
v
)
for
k
,
v
in
sorted
(
run_params
.
items
())]
def
_collect_tensorflow_environment_variables
(
run_info
):
def
_collect_tensorflow_environment_variables
(
run_info
):
run_info
[
"tensorflow_environment_variables"
]
=
[
run_info
[
"tensorflow_environment_variables"
]
=
[
{
"name"
:
k
,
"value"
:
v
}
{
"name"
:
k
,
"value"
:
v
}
...
@@ -213,15 +234,18 @@ def _collect_cpu_info(run_info):
...
@@ -213,15 +234,18 @@ def _collect_cpu_info(run_info):
cpu_info
[
"num_cores"
]
=
multiprocessing
.
cpu_count
()
cpu_info
[
"num_cores"
]
=
multiprocessing
.
cpu_count
()
# Note: cpuinfo is not installed in the TensorFlow OSS tree.
try
:
# It is installable via pip.
# Note: cpuinfo is not installed in the TensorFlow OSS tree.
import
cpuinfo
# pylint: disable=g-import-not-at-top
# It is installable via pip.
import
cpuinfo
# pylint: disable=g-import-not-at-top
info
=
cpuinfo
.
get_cpu_info
()
info
=
cpuinfo
.
get_cpu_info
()
cpu_info
[
"cpu_info"
]
=
info
[
"brand"
]
cpu_info
[
"cpu_info"
]
=
info
[
"brand"
]
cpu_info
[
"mhz_per_cpu"
]
=
info
[
"hz_advertised_raw"
][
0
]
/
1.0e6
cpu_info
[
"mhz_per_cpu"
]
=
info
[
"hz_advertised_raw"
][
0
]
/
1.0e6
run_info
[
"machine_config"
][
"cpu_info"
]
=
cpu_info
run_info
[
"machine_config"
][
"cpu_info"
]
=
cpu_info
except
ImportError
:
tf
.
logging
.
warn
(
"'cpuinfo' not imported. CPU info will not be logged."
)
def
_collect_gpu_info
(
run_info
):
def
_collect_gpu_info
(
run_info
):
...
@@ -243,12 +267,15 @@ def _collect_gpu_info(run_info):
...
@@ -243,12 +267,15 @@ def _collect_gpu_info(run_info):
def
_collect_memory_info
(
run_info
):
def
_collect_memory_info
(
run_info
):
# Note: psutil is not installed in the TensorFlow OSS tree.
try
:
# It is installable via pip.
# Note: psutil is not installed in the TensorFlow OSS tree.
import
psutil
# pylint: disable=g-import-not-at-top
# It is installable via pip.
vmem
=
psutil
.
virtual_memory
()
import
psutil
# pylint: disable=g-import-not-at-top
run_info
[
"machine_config"
][
"memory_total"
]
=
vmem
.
total
vmem
=
psutil
.
virtual_memory
()
run_info
[
"machine_config"
][
"memory_available"
]
=
vmem
.
available
run_info
[
"machine_config"
][
"memory_total"
]
=
vmem
.
total
run_info
[
"machine_config"
][
"memory_available"
]
=
vmem
.
available
except
ImportError
:
tf
.
logging
.
warn
(
"'psutil' not imported. Memory info will not be logged."
)
def
_parse_gpu_model
(
physical_device_desc
):
def
_parse_gpu_model
(
physical_device_desc
):
...
...
official/utils/logs/logger_test.py
View file @
3b158095
...
@@ -180,6 +180,32 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
...
@@ -180,6 +180,32 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"version"
],
tf
.
VERSION
)
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"version"
],
tf
.
VERSION
)
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"git_hash"
],
tf
.
GIT_VERSION
)
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"git_hash"
],
tf
.
GIT_VERSION
)
def
test_collect_run_params
(
self
):
run_info
=
{}
run_parameters
=
{
"batch_size"
:
32
,
"synthetic_data"
:
True
,
"train_epochs"
:
100.00
,
"dtype"
:
"fp16"
,
"resnet_size"
:
50
,
"random_tensor"
:
tf
.
constant
(
2.0
)
}
logger
.
_collect_run_params
(
run_info
,
run_parameters
)
self
.
assertEqual
(
len
(
run_info
[
"run_parameters"
]),
6
)
self
.
assertEqual
(
run_info
[
"run_parameters"
][
0
],
{
"name"
:
"batch_size"
,
"long_value"
:
32
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
1
],
{
"name"
:
"dtype"
,
"string_value"
:
"fp16"
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
2
],
{
"name"
:
"random_tensor"
,
"string_value"
:
"Tensor(
\"
Const:0
\"
, shape=(), dtype=float32)"
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
3
],
{
"name"
:
"resnet_size"
,
"long_value"
:
50
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
4
],
{
"name"
:
"synthetic_data"
,
"bool_value"
:
"True"
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
5
],
{
"name"
:
"train_epochs"
,
"float_value"
:
100.00
})
def
test_collect_tensorflow_environment_variables
(
self
):
def
test_collect_tensorflow_environment_variables
(
self
):
os
.
environ
[
"TF_ENABLE_WINOGRAD_NONFUSED"
]
=
"1"
os
.
environ
[
"TF_ENABLE_WINOGRAD_NONFUSED"
]
=
"1"
os
.
environ
[
"TF_OTHER"
]
=
"2"
os
.
environ
[
"TF_OTHER"
]
=
"2"
...
...
official/utils/testing/integration.py
View file @
3b158095
...
@@ -19,12 +19,15 @@ from __future__ import absolute_import
...
@@ -19,12 +19,15 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
os
import
shutil
import
shutil
import
sys
import
sys
import
tempfile
import
tempfile
from
absl
import
flags
from
official.utils.flags
import
core
as
flags_core
def
run_synthetic
(
main
,
tmp_root
,
extra_flags
=
None
,
synth
=
True
,
max_train
=
1
):
def
run_synthetic
(
main
,
tmp_root
,
extra_flags
=
None
,
synth
=
True
,
max_train
=
1
):
"""Performs a minimal run of a model.
"""Performs a minimal run of a model.
...
@@ -55,7 +58,8 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
...
@@ -55,7 +58,8 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
args
.
extend
([
"--max_train_steps"
,
str
(
max_train
)])
args
.
extend
([
"--max_train_steps"
,
str
(
max_train
)])
try
:
try
:
main
(
args
)
flags_core
.
parse_flags
(
argv
=
args
)
main
(
flags
.
FLAGS
)
finally
:
finally
:
if
os
.
path
.
exists
(
model_dir
):
if
os
.
path
.
exists
(
model_dir
):
shutil
.
rmtree
(
model_dir
)
shutil
.
rmtree
(
model_dir
)
official/utils/testing/pylint.rcfile
View file @
3b158095
...
@@ -61,7 +61,7 @@ variable-rgx=^[a-z][a-z0-9_]*$
...
@@ -61,7 +61,7 @@ variable-rgx=^[a-z][a-z0-9_]*$
# (useful for modules/projects where namespaces are manipulated during runtime
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
ignored-modules=
absl, absl.*,
official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
[CLASSES]
[CLASSES]
...
...
official/wide_deep/wide_deep.py
View file @
3b158095
...
@@ -17,17 +17,18 @@ from __future__ import absolute_import
...
@@ -17,17 +17,18 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
argparse
import
os
import
os
import
shutil
import
shutil
import
sys
from
absl
import
app
as
absl_app
from
absl
import
flags
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.utils.
arg_parsers
import
parsers
from
official.utils.
flags
import
core
as
flags_core
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
hooks_helper
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
_CSV_COLUMNS
=
[
_CSV_COLUMNS
=
[
'age'
,
'workclass'
,
'fnlwgt'
,
'education'
,
'education_num'
,
'age'
,
'workclass'
,
'fnlwgt'
,
'education'
,
'education_num'
,
'marital_status'
,
'occupation'
,
'relationship'
,
'race'
,
'gender'
,
'marital_status'
,
'occupation'
,
'relationship'
,
'race'
,
'gender'
,
...
@@ -47,6 +48,24 @@ _NUM_EXAMPLES = {
...
@@ -47,6 +48,24 @@ _NUM_EXAMPLES = {
LOSS_PREFIX
=
{
'wide'
:
'linear/'
,
'deep'
:
'dnn/'
}
LOSS_PREFIX
=
{
'wide'
:
'linear/'
,
'deep'
:
'dnn/'
}
def
define_wide_deep_flags
():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core
.
define_base
()
flags
.
adopt_module_key_flags
(
flags_core
)
flags
.
DEFINE_enum
(
name
=
"model_type"
,
short_name
=
"mt"
,
default
=
"wide_deep"
,
enum_values
=
[
'wide'
,
'deep'
,
'wide_deep'
],
help
=
"Select model topology."
)
flags_core
.
set_defaults
(
data_dir
=
'/tmp/census_data'
,
model_dir
=
'/tmp/census_model'
,
train_epochs
=
40
,
epochs_between_evals
=
2
,
batch_size
=
40
)
def
build_model_columns
():
def
build_model_columns
():
"""Builds a set of wide and deep feature columns."""
"""Builds a set of wide and deep feature columns."""
# Continuous columns
# Continuous columns
...
@@ -196,70 +215,60 @@ def export_model(model, model_type, export_dir):
...
@@ -196,70 +215,60 @@ def export_model(model, model_type, export_dir):
model
.
export_savedmodel
(
export_dir
,
example_input_fn
)
model
.
export_savedmodel
(
export_dir
,
example_input_fn
)
def
main
(
argv
):
def
run_wide_deep
(
flags_obj
):
parser
=
WideDeepArgParser
()
"""Run Wide-Deep training and eval loop.
flags
=
parser
.
parse_args
(
args
=
argv
[
1
:])
Args:
flags_obj: An object containing parsed flag values.
"""
# Clean up the model directory if present
# Clean up the model directory if present
shutil
.
rmtree
(
flags
.
model_dir
,
ignore_errors
=
True
)
shutil
.
rmtree
(
flags
_obj
.
model_dir
,
ignore_errors
=
True
)
model
=
build_estimator
(
flags
.
model_dir
,
flags
.
model_type
)
model
=
build_estimator
(
flags
_obj
.
model_dir
,
flags
_obj
.
model_type
)
train_file
=
os
.
path
.
join
(
flags
.
data_dir
,
'adult.data'
)
train_file
=
os
.
path
.
join
(
flags
_obj
.
data_dir
,
'adult.data'
)
test_file
=
os
.
path
.
join
(
flags
.
data_dir
,
'adult.test'
)
test_file
=
os
.
path
.
join
(
flags
_obj
.
data_dir
,
'adult.test'
)
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
def
train_input_fn
():
def
train_input_fn
():
return
input_fn
(
return
input_fn
(
train_file
,
flags
.
epochs_between_evals
,
True
,
flags
.
batch_size
)
train_file
,
flags
_obj
.
epochs_between_evals
,
True
,
flags
_obj
.
batch_size
)
def
eval_input_fn
():
def
eval_input_fn
():
return
input_fn
(
test_file
,
1
,
False
,
flags
.
batch_size
)
return
input_fn
(
test_file
,
1
,
False
,
flags
_obj
.
batch_size
)
loss_prefix
=
LOSS_PREFIX
.
get
(
flags
.
model_type
,
''
)
loss_prefix
=
LOSS_PREFIX
.
get
(
flags
_obj
.
model_type
,
''
)
train_hooks
=
hooks_helper
.
get_train_hooks
(
train_hooks
=
hooks_helper
.
get_train_hooks
(
flags
.
hooks
,
batch_size
=
flags
.
batch_size
,
flags
_obj
.
hooks
,
batch_size
=
flags
_obj
.
batch_size
,
tensors_to_log
=
{
'average_loss'
:
loss_prefix
+
'head/truediv'
,
tensors_to_log
=
{
'average_loss'
:
loss_prefix
+
'head/truediv'
,
'loss'
:
loss_prefix
+
'head/weighted_loss/Sum'
})
'loss'
:
loss_prefix
+
'head/weighted_loss/Sum'
})
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for
n
in
range
(
flags
.
train_epochs
//
flags
.
epochs_between_evals
):
for
n
in
range
(
flags
_obj
.
train_epochs
//
flags
_obj
.
epochs_between_evals
):
model
.
train
(
input_fn
=
train_input_fn
,
hooks
=
train_hooks
)
model
.
train
(
input_fn
=
train_input_fn
,
hooks
=
train_hooks
)
results
=
model
.
evaluate
(
input_fn
=
eval_input_fn
)
results
=
model
.
evaluate
(
input_fn
=
eval_input_fn
)
# Display evaluation metrics
# Display evaluation metrics
print
(
'Results at epoch'
,
(
n
+
1
)
*
flags
.
epochs_between_evals
)
print
(
'Results at epoch'
,
(
n
+
1
)
*
flags
_obj
.
epochs_between_evals
)
print
(
'-'
*
60
)
print
(
'-'
*
60
)
for
key
in
sorted
(
results
):
for
key
in
sorted
(
results
):
print
(
'%s: %s'
%
(
key
,
results
[
key
]))
print
(
'%s: %s'
%
(
key
,
results
[
key
]))
if
model_helpers
.
past_stop_threshold
(
if
model_helpers
.
past_stop_threshold
(
flags
.
stop_threshold
,
results
[
'accuracy'
]):
flags
_obj
.
stop_threshold
,
results
[
'accuracy'
]):
break
break
# Export the model
# Export the model
if
flags
.
export_dir
is
not
None
:
if
flags_obj
.
export_dir
is
not
None
:
export_model
(
model
,
flags
.
model_type
,
flags
.
export_dir
)
export_model
(
model
,
flags_obj
.
model_type
,
flags_obj
.
export_dir
)
class
WideDeepArgParser
(
argparse
.
ArgumentParser
):
"""Argument parser for running the wide deep model."""
def
__init__
(
self
):
def
main
(
_
):
super
(
WideDeepArgParser
,
self
).
__init__
(
parents
=
[
parsers
.
BaseParser
()])
run_wide_deep
(
flags
.
FLAGS
)
self
.
add_argument
(
'--model_type'
,
'-mt'
,
type
=
str
,
default
=
'wide_deep'
,
choices
=
[
'wide'
,
'deep'
,
'wide_deep'
],
help
=
'[default %(default)s] Valid model types: wide, deep, wide_deep.'
,
metavar
=
'<MT>'
)
self
.
set_defaults
(
data_dir
=
'/tmp/census_data'
,
model_dir
=
'/tmp/census_model'
,
train_epochs
=
40
,
epochs_between_evals
=
2
,
batch_size
=
40
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
main
(
argv
=
sys
.
argv
)
define_wide_deep_flags
()
absl_app
.
run
(
main
)
official/wide_deep/wide_deep_test.py
View file @
3b158095
...
@@ -48,6 +48,11 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
...
@@ -48,6 +48,11 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
class
BaseTest
(
tf
.
test
.
TestCase
):
class
BaseTest
(
tf
.
test
.
TestCase
):
"""Tests for Wide Deep model."""
"""Tests for Wide Deep model."""
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
BaseTest
,
cls
).
setUpClass
()
wide_deep
.
define_wide_deep_flags
()
def
setUp
(
self
):
def
setUp
(
self
):
# Create temporary CSV file
# Create temporary CSV file
self
.
temp_dir
=
self
.
get_temp_dir
()
self
.
temp_dir
=
self
.
get_temp_dir
()
...
...
research/README.md
View file @
3b158095
...
@@ -26,7 +26,7 @@ installation](https://www.tensorflow.org/install).
...
@@ -26,7 +26,7 @@ installation](https://www.tensorflow.org/install).
for visual navigation.
for visual navigation.
-
[
compression
](
compression
)
: compressing and decompressing images using a
-
[
compression
](
compression
)
: compressing and decompressing images using a
pre-trained Residual GRU network.
pre-trained Residual GRU network.
-
[
deeplab
](
deeplab
)
: deep label
l
ing for semantic image segmentation.
-
[
deeplab
](
deeplab
)
: deep labeling for semantic image segmentation.
-
[
delf
](
delf
)
: deep local features for image matching and retrieval.
-
[
delf
](
delf
)
: deep local features for image matching and retrieval.
-
[
differential_privacy
](
differential_privacy
)
: differential privacy for training
-
[
differential_privacy
](
differential_privacy
)
: differential privacy for training
data.
data.
...
@@ -55,6 +55,7 @@ installation](https://www.tensorflow.org/install).
...
@@ -55,6 +55,7 @@ installation](https://www.tensorflow.org/install).
-
[
pcl_rl
](
pcl_rl
)
: code for several reinforcement learning algorithms,
-
[
pcl_rl
](
pcl_rl
)
: code for several reinforcement learning algorithms,
including Path Consistency Learning.
including Path Consistency Learning.
-
[
ptn
](
ptn
)
: perspective transformer nets for 3D object reconstruction.
-
[
ptn
](
ptn
)
: perspective transformer nets for 3D object reconstruction.
-
[
marco
](
marco
)
: automating the evaluation of crystallization experiments.
-
[
qa_kg
](
qa_kg
)
: module networks for question answering on knowledge graphs.
-
[
qa_kg
](
qa_kg
)
: module networks for question answering on knowledge graphs.
-
[
real_nvp
](
real_nvp
)
: density estimation using real-valued non-volume
-
[
real_nvp
](
real_nvp
)
: density estimation using real-valued non-volume
preserving (real NVP) transformations.
preserving (real NVP) transformations.
...
...
research/differential_privacy/__init__.py
View file @
3b158095
research/differential_privacy/multiple_teachers/__init__.py
0 → 100644
View file @
3b158095
research/differential_privacy/multiple_teachers/analysis.py
View file @
3b158095
...
@@ -240,7 +240,7 @@ def main(unused_argv):
...
@@ -240,7 +240,7 @@ def main(unused_argv):
counts_mat
=
np
.
zeros
((
n
,
10
)).
astype
(
np
.
int32
)
counts_mat
=
np
.
zeros
((
n
,
10
)).
astype
(
np
.
int32
)
for
i
in
range
(
n
):
for
i
in
range
(
n
):
for
j
in
range
(
num_teachers
):
for
j
in
range
(
num_teachers
):
counts_mat
[
i
,
input_mat
[
j
,
i
]]
+=
1
counts_mat
[
i
,
int
(
input_mat
[
j
,
i
]
)
]
+=
1
n
=
counts_mat
.
shape
[
0
]
n
=
counts_mat
.
shape
[
0
]
num_examples
=
min
(
n
,
FLAGS
.
max_examples
)
num_examples
=
min
(
n
,
FLAGS
.
max_examples
)
...
...
research/differential_privacy/pate/ICLR2018/plot_partition.py
View file @
3b158095
...
@@ -186,7 +186,7 @@ def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
...
@@ -186,7 +186,7 @@ def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
ss
=
rdp_ss
[
order_idx
],
ss
=
rdp_ss
[
order_idx
],
delta
=-
math
.
log
(
delta
)
/
(
order_opt
[
i
]
-
1
))
delta
=-
math
.
log
(
delta
)
/
(
order_opt
[
i
]
-
1
))
ss_std_opt
[
i
]
=
ss_std
[
order_idx
]
ss_std_opt
[
i
]
=
ss_std
[
order_idx
]
if
i
>
0
and
(
i
+
1
)
%
1
0
==
0
:
if
i
>
0
and
(
i
+
1
)
%
1
==
0
:
print
(
'queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
print
(
'queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
'step2 = {:.3f}, ss = {:.3f}'
.
format
(
'step2 = {:.3f}, ss = {:.3f}'
.
format
(
...
...
research/differential_privacy/pate/README.md
View file @
3b158095
...
@@ -15,7 +15,7 @@ dataset.
...
@@ -15,7 +15,7 @@ dataset.
The framework consists of _teachers_, the _student_ model, and the _aggregator_. The
The framework consists of _teachers_, the _student_ model, and the _aggregator_. The
teachers are models trained on disjoint subsets of the training datasets. The student
teachers are models trained on disjoint subsets of the training datasets. The student
model has access to an insensitive (
i.e
., public) unlabelled dataset, which is labelled by
model has access to an insensitive (
e.g
., public) unlabelled dataset, which is labelled by
interacting with the ensemble of teachers via the _aggregator_. The aggregator tallies
interacting with the ensemble of teachers via the _aggregator_. The aggregator tallies
outputs of the teacher models, and either forwards a (noisy) aggregate to the student, or
outputs of the teacher models, and either forwards a (noisy) aggregate to the student, or
refuses to answer.
refuses to answer.
...
@@ -57,13 +57,13 @@ $ python smooth_sensitivity_test.py
...
@@ -57,13 +57,13 @@ $ python smooth_sensitivity_test.py
## Files in this directory
## Files in this directory
*
core.py
---
RDP privacy accountant for several vote aggregators (GNMax,
*
core.py
—
RDP privacy accountant for several vote aggregators (GNMax,
Threshold, Laplace).
Threshold, Laplace).
*
smooth_sensitivity.py
---
Smooth sensitivity analysis for GNMax and
*
smooth_sensitivity.py
—
Smooth sensitivity analysis for GNMax and
Threshold mechanisms.
Threshold mechanisms.
*
core_test.py and smooth_sensitivity_test.py
---
Unit tests for the
*
core_test.py and smooth_sensitivity_test.py
—
Unit tests for the
files above.
files above.
## Contact information
## Contact information
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment