ModelZoo / ResNet50_tensorflow / Commits

Commit 3b158095, authored May 07, 2018 by Ilya Mironov

    Merge branch 'master' of https://github.com/ilyamironov/models

Parents: a90db800, be659c2f

Changes: 121 files in total; this page shows 20 changed files with 693 additions and 112 deletions (+693, -112).
official/utils/flags/_base.py                                    +140  -0
official/utils/flags/_benchmark.py                                +67  -0
official/utils/flags/_conventions.py                              +36  -0
official/utils/flags/_misc.py                                     +50  -0
official/utils/flags/_performance.py                             +132  -0
official/utils/flags/core.py                                      +84  -0
official/utils/flags/flags_test.py                                +31  -37
official/utils/logs/benchmark_uploader.py                         +15  -11
official/utils/logs/logger.py                                     +45  -18
official/utils/logs/logger_test.py                                +26  -0
official/utils/testing/integration.py                              +6  -2
official/utils/testing/pylint.rcfile                               +1  -1
official/wide_deep/wide_deep.py                                   +45  -36
official/wide_deep/wide_deep_test.py                               +5  -0
research/README.md                                                 +2  -1
research/differential_privacy/__init__.py                          +1  -0
research/differential_privacy/multiple_teachers/__init__.py        +1  -0
research/differential_privacy/multiple_teachers/analysis.py        +1  -1
research/differential_privacy/pate/ICLR2018/plot_partition.py      +1  -1
research/differential_privacy/pate/README.md                       +4  -4
official/utils/flags/_base.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper


def define_base(data_dir=True, model_dir=True, train_epochs=True,
                epochs_between_evals=True, stop_threshold=True,
                batch_size=True, multi_gpu=False, num_gpu=True, hooks=True,
                export_dir=True):
  """Register base flags.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other
      eval metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    multi_gpu: Create a flag to allow the use of all available GPUs.
    num_gpu: Create a flag to specify the number of GPUs used.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir", short_name="dd", default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir", short_name="md", default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs", short_name="te", default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals", short_name="ebe", default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold", short_name="st", default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size", short_name="bs", default=32,
        help=help_wrap("Batch size for training and evaluation."))
    key_flags.append("batch_size")

  assert not (multi_gpu and num_gpu)

  if multi_gpu:
    flags.DEFINE_bool(
        name="multi_gpu", default=False,
        help=help_wrap("If set, run across all available GPUs."))
    key_flags.append("multi_gpu")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus", short_name="ng",
        default=1 if tf.test.is_gpu_available() else 0,
        help=help_wrap(
            "How many GPUs to use with the DistributionStrategies API. The "
            "default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))

  if hooks:
    # Construct a pretty summary of hooks.
    hook_list_str = (
        u"\ufeffHook:\n" + u"\n".join(
            [u"\ufeff{}".format(key) for key in hooks_helper.HOOKS]))
    flags.DEFINE_list(
        name="hooks", short_name="hk", default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks.\n{}\n\ufeffExample: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\nSee official.utils.logs.hooks_helper "
            u"for details.".format(hook_list_str))
    )
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir", short_name="ed", default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links.")
    )
    key_flags.append("export_dir")

  return key_flags


def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'."""
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  return sum([1 for d in local_device_protos if d.device_type == "GPU"])
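
A minimal sketch of exercising the num_gpus flag defined above, including the -1 "use all" convention handled by get_num_gpus (the "demo" program name and the parse-by-calling-FLAGS style are illustrative, not part of this commit):

from absl import flags

from official.utils.flags import _base

_base.define_base()                         # registers num_gpus among others
flags.FLAGS(["demo", "--num_gpus", "-1"])   # absl parses argv-style lists
print(_base.get_num_gpus(flags.FLAGS))      # local GPU count; 0 on CPU-only hosts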
official/utils/flags/_benchmark.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for benchmarking models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

from official.utils.flags._conventions import help_wrap


def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
  """Register benchmarking flags.

  Args:
    benchmark_log_dir: Create a flag to specify location for benchmark logging.
    bigquery_uploader: Create flags for uploading results to BigQuery.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if benchmark_log_dir:
    flags.DEFINE_string(
        name="benchmark_log_dir", short_name="bld", default=None,
        help=help_wrap("The location of the benchmark logging.")
    )

  if bigquery_uploader:
    flags.DEFINE_string(
        name="gcp_project", short_name="gp", default=None,
        help=help_wrap(
            "The GCP project name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_data_set", short_name="bds", default="test_benchmark",
        help=help_wrap(
            "The Bigquery dataset name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_table", short_name="brt", default="benchmark_run",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_metric_table", short_name="bmt",
        default="benchmark_metric",
        help=help_wrap("The Bigquery table name where the benchmark metric "
                       "information will be uploaded."))

  return key_flags
official/utils/flags/_conventions.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for shared arparse convention definitions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl import app as absl_app
from absl import flags


# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
                              firstline_indent="\n")


# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
official/utils/flags/_misc.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc flags."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

from official.utils.flags._conventions import help_wrap


def define_image(data_format=True):
  """Register image specific flags.

  Args:
    data_format: Create a flag to specify image axis convention.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if data_format:
    flags.DEFINE_enum(
        name="data_format", short_name="df", default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    key_flags.append("data_format")

  return key_flags
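
A hypothetical consumer of the data_format flag might resolve the None default like this (resolve_data_format is illustrative; models in this repo implement their own variant of this check):

import tensorflow as tf


def resolve_data_format(flags_obj):
  # Honor an explicit --data_format; otherwise follow the help text above:
  # channels_first for GPU builds, channels_last for CPU builds.
  if flags_obj.data_format is not None:
    return flags_obj.data_format
  return "channels_first" if tf.test.is_gpu_available() else "channels_last"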
official/utils/flags/_performance.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from absl import flags   # pylint: disable=g-bad-import-order
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.flags._conventions import help_wrap


# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
    "fp16": (tf.float16, 128),
    "fp32": (tf.float32, 1),
}


def get_tf_dtype(flags_obj):
  return DTYPE_MAP[flags_obj.dtype][0]


def get_loss_scale(flags_obj):
  if flags_obj.loss_scale is not None:
    return flags_obj.loss_scale
  return DTYPE_MAP[flags_obj.dtype][1]


def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                       synthetic_data=True, max_train_steps=True, dtype=True):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of maximum number
      of training steps.
    dtype: Create flags for specifying dtype.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details.")
    )

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    flags.DEFINE_integer(
        name="loss_scale", short_name="ls", default=None,
        help=help_wrap(
            "The amount to scale the loss by when the model is run. Before "
            "gradients are computed, the loss is multiplied by the loss scale, "
            "making all gradients loss_scale times larger. To adjust for this, "
            "gradients are divided by the loss scale before being applied to "
            "variables. This is mathematically equivalent to training without "
            "a loss scale, but the loss scale helps avoid some intermediate "
            "gradients from underflowing to zero. If not provided the default "
            "for fp16 is 128 and 1 for all other dtypes."))

    loss_scale_val_msg = "loss_scale should be a positive integer."
    @flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
    def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
      if loss_scale is None:
        return True  # null case is handled in get_loss_scale()
      return loss_scale > 0

  return key_flags
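
The loss_scale help text above compresses the whole technique into prose; here is a sketch of the arithmetic it describes (scaled_gradients is illustrative, not a function in this commit):

import tensorflow as tf


def scaled_gradients(loss, variables, loss_scale=128):
  # Multiply the loss, differentiate, then divide the scale back out.
  # Mathematically a no-op, but intermediate fp16 gradients are loss_scale
  # times larger, so small gradients are less likely to underflow to zero.
  grads = tf.gradients(loss * loss_scale, variables)
  return [g / loss_scale if g is not None else None for g in grads]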
official/utils/flags/core.py (new file, mode 100644)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public interface for flag definition.
See _example.py for detailed instructions on defining flags.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import sys

from absl import app as absl_app
from absl import flags

from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _misc
from official.utils.flags import _performance


def set_defaults(**kwargs):
  for key, value in kwargs.items():
    flags.FLAGS.set_default(name=key, value=value)


def parse_flags(argv=None):
  """Reset flags and reparse. Currently only used in testing."""
  flags.FLAGS.unparse_flags()
  absl_app.parse_flags_with_usage(argv or sys.argv)


def register_key_flags_in_core(f):
  """Defines a function in core.py, and registers its key flags.

  absl uses the location of a flags.declare_key_flag() to determine the context
  in which a flag is key. By making all declares in core, this allows model
  main functions to call flags.adopt_module_key_flags() on core and correctly
  chain key flags.

  Args:
    f: The function to be wrapped

  Returns:
    The "core-defined" version of the input function.
  """

  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    [flags.declare_key_flag(fl) for fl in key_flags]  # pylint: disable=expression-not-assigned

  return core_fn


define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial(
    _base.define_base, epochs_between_evals=False, stop_threshold=False,
    multi_gpu=False, hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)

help_wrap = _conventions.help_wrap

get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
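
Putting core.py together with the flag modules, the intended consumption pattern looks roughly like this (a toy main; wide_deep.py below is the real example in this commit):

from absl import app as absl_app
from absl import flags

from official.utils.flags import core as flags_core


def main(_):
  print("batch size:", flags.FLAGS.batch_size)


if __name__ == "__main__":
  flags_core.define_base()                 # registers and declares key flags
  flags.adopt_module_key_flags(flags_core)
  absl_app.run(main)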
official/utils/arg_parsers/parsers_test.py → official/utils/flags/flags_test.py
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -13,29 +13,28 @@
 # limitations under the License.
 # ==============================================================================

-import argparse
 import unittest

-import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import flags
+import tensorflow as tf

-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order


-class TestParser(argparse.ArgumentParser):
-  """Class to test canned parser functionality."""
-
-  def __init__(self):
-    super(TestParser, self).__init__(parents=[
-        parsers.BaseParser(),
-        parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
-                                  intra_op=True, use_synthetic_data=True),
-        parsers.ImageModelParser(data_format=True),
-        parsers.BenchmarkParser(benchmark_log_dir=True,
-                                bigquery_uploader=True)
-    ])
+def define_flags():
+  flags_core.define_base(multi_gpu=True, num_gpu=False)
+  flags_core.define_performance()
+  flags_core.define_image()
+  flags_core.define_benchmark()


 class BaseTester(unittest.TestCase):

+  @classmethod
+  def setUpClass(cls):
+    super(BaseTester, cls).setUpClass()
+    define_flags()
+
   def test_default_setting(self):
     """Test to ensure fields exist and defaults can be set.
     """

@@ -49,16 +48,15 @@ class BaseTester(unittest.TestCase):
         hooks=["LoggingTensorHook"],
         num_parallel_calls=18,
         inter_op_parallelism_threads=5,
-        intra_op_parallelism_thread=10,
+        intra_op_parallelism_threads=10,
         data_format="channels_first")

-    parser = TestParser()
-    parser.set_defaults(**defaults)
+    flags_core.set_defaults(**defaults)
+    flags_core.parse_flags()

-    namespace_vars = vars(parser.parse_args([]))
     for key, value in defaults.items():
-      assert namespace_vars[key] == value
+      assert flags.FLAGS.get_flag_value(name=key, default=None) == value

   def test_benchmark_setting(self):
     defaults = dict(

@@ -67,40 +65,36 @@ class BaseTester(unittest.TestCase):
         gcp_project="project_abc",
     )

-    parser = TestParser()
-    parser.set_defaults(**defaults)
+    flags_core.set_defaults(**defaults)
+    flags_core.parse_flags()

-    namespace_vars = vars(parser.parse_args([]))
     for key, value in defaults.items():
-      assert namespace_vars[key] == value
+      assert flags.FLAGS.get_flag_value(name=key, default=None) == value

   def test_booleans(self):
     """Test to ensure boolean flags trigger as expected.
     """
-    parser = TestParser()
-    namespace = parser.parse_args(["--multi_gpu", "--use_synthetic_data"])
+    flags_core.parse_flags([__file__, "--multi_gpu", "--use_synthetic_data"])

-    assert namespace.multi_gpu
-    assert namespace.use_synthetic_data
+    assert flags.FLAGS.multi_gpu
+    assert flags.FLAGS.use_synthetic_data

   def test_parse_dtype_info(self):
-    parser = TestParser()
     for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
                                             ["fp32", tf.float32, 1]]:
-      args = parser.parse_args(["--dtype", dtype_str])
-      parsers.parse_dtype_info(args)
+      flags_core.parse_flags([__file__, "--dtype", dtype_str])

-      assert args.dtype == tf_dtype
-      assert args.loss_scale == loss_scale
+      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
+      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)

-      args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
-      parsers.parse_dtype_info(args)
+      flags_core.parse_flags(
+          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])

-      assert args.loss_scale == 5
+      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)

     with self.assertRaises(SystemExit):
-      parser.parse_args(["--dtype", "int8"])
+      flags_core.parse_flags([__file__, "--dtype", "int8"])


 if __name__ == "__main__":
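
Note the leading __file__ in the parse_flags calls above: absl treats argv[0] as the program name, so flag strings must start at index 1. A self-contained illustration (demo_batch_size is a made-up flag):

from absl import app as absl_app
from absl import flags

flags.DEFINE_integer(name="demo_batch_size", default=32, help="demo flag")

absl_app.parse_flags_with_usage(["prog", "--demo_batch_size", "64"])
print(flags.FLAGS.demo_batch_size)  # 64; "prog" was skipped as the program name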
official/utils/logs/benchmark_uploader.py
@@ -31,9 +31,13 @@ import uuid

 from google.cloud import bigquery
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import tensorflow as tf
+# pylint: enable=g-bad-import-order

-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.logs import logger

@@ -108,22 +112,22 @@ class BigQueryUploader(object):
           "Failed to upload benchmark info to bigquery: {}".format(errors))


-def main(argv):
-  parser = parsers.BenchmarkParser()
-  flags = parser.parse_args(args=argv[1:])
-
-  if not flags.benchmark_log_dir:
+def main(_):
+  if not flags.FLAGS.benchmark_log_dir:
     print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
     sys.exit(1)

   uploader = BigQueryUploader(
-      flags.benchmark_log_dir,
-      gcp_project=flags.gcp_project)
+      flags.FLAGS.benchmark_log_dir,
+      gcp_project=flags.FLAGS.gcp_project)
   run_id = str(uuid.uuid4())
   uploader.upload_benchmark_run(
-      flags.bigquery_data_set, flags.bigquery_run_table, run_id)
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id)
   uploader.upload_metric(
-      flags.bigquery_data_set, flags.bigquery_metric_table, run_id)
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id)


 if __name__ == "__main__":
-  main(argv=sys.argv)
+  flags_core.define_benchmark()
+  flags.adopt_module_key_flags(flags_core)
+  absl_app.run(main=main)
official/utils/logs/logger.py
@@ -109,8 +109,9 @@ class BaseBenchmarkLogger(object):
                     "Name %s, value %d, unit %s, global_step %d, extras %s",
                     name, value, unit, global_step, extras)

-  def log_run_info(self, model_name):
-    tf.logging.info("Benchmark run: %s", _gather_run_info(model_name))
+  def log_run_info(self, model_name, dataset_name, run_params):
+    tf.logging.info("Benchmark run: %s",
+                    _gather_run_info(model_name, dataset_name, run_params))


 class BenchmarkFileLogger(BaseBenchmarkLogger):

@@ -159,15 +160,18 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
       tf.logging.warning("Failed to dump metric to log file: "
                          "name %s, value %s, error %s", name, value, e)

-  def log_run_info(self, model_name):
+  def log_run_info(self, model_name, dataset_name, run_params):
     """Collect most of the TF runtime information for the local env.

     The schema of the run info follows official/benchmark/datastore/schema.

     Args:
       model_name: string, the name of the model.
+      dataset_name: string, the name of dataset for training and evaluation.
+      run_params: dict, the dictionary of parameters for the run, it could
+        include hyperparameters or other params that are important for the
+        run.
     """
-    run_info = _gather_run_info(model_name)
+    run_info = _gather_run_info(model_name, dataset_name, run_params)

     with tf.gfile.GFile(os.path.join(
         self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:

@@ -179,15 +183,17 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
                          e)


-def _gather_run_info(model_name):
+def _gather_run_info(model_name, dataset_name, run_params):
   """Collect the benchmark run information for the local environment."""
   run_info = {
       "model_name": model_name,
+      "dataset": {"name": dataset_name},
       "machine_config": {},
       "run_date": datetime.datetime.utcnow().strftime(
           _DATE_TIME_FORMAT_PATTERN)}
   _collect_tensorflow_info(run_info)
   _collect_tensorflow_environment_variables(run_info)
+  _collect_run_params(run_info, run_params)
   _collect_cpu_info(run_info)
   _collect_gpu_info(run_info)
   _collect_memory_info(run_info)

@@ -199,6 +205,21 @@ def _collect_tensorflow_info(run_info):
       "version": tf.VERSION, "git_hash": tf.GIT_VERSION}


+def _collect_run_params(run_info, run_params):
+  """Log the parameter information for the benchmark run."""
+  def process_param(name, value):
+    type_check = {
+        str: {"name": name, "string_value": value},
+        int: {"name": name, "long_value": value},
+        bool: {"name": name, "bool_value": str(value)},
+        float: {"name": name, "float_value": value},
+    }
+    return type_check.get(type(value),
+                          {"name": name, "string_value": str(value)})
+  if run_params:
+    run_info["run_parameters"] = [
+        process_param(k, v) for k, v in sorted(run_params.items())]
+
+
 def _collect_tensorflow_environment_variables(run_info):
   run_info["tensorflow_environment_variables"] = [
       {"name": k, "value": v}

@@ -213,15 +234,18 @@ def _collect_cpu_info(run_info):
   cpu_info["num_cores"] = multiprocessing.cpu_count()

-  # Note: cpuinfo is not installed in the TensorFlow OSS tree.
-  # It is installable via pip.
-  import cpuinfo  # pylint: disable=g-import-not-at-top
+  try:
+    # Note: cpuinfo is not installed in the TensorFlow OSS tree.
+    # It is installable via pip.
+    import cpuinfo  # pylint: disable=g-import-not-at-top

-  info = cpuinfo.get_cpu_info()
-  cpu_info["cpu_info"] = info["brand"]
-  cpu_info["mhz_per_cpu"] = info["hz_advertised_raw"][0] / 1.0e6
+    info = cpuinfo.get_cpu_info()
+    cpu_info["cpu_info"] = info["brand"]
+    cpu_info["mhz_per_cpu"] = info["hz_advertised_raw"][0] / 1.0e6

-  run_info["machine_config"]["cpu_info"] = cpu_info
+    run_info["machine_config"]["cpu_info"] = cpu_info
+  except ImportError:
+    tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.")


 def _collect_gpu_info(run_info):

@@ -243,12 +267,15 @@ def _collect_gpu_info(run_info):
 def _collect_memory_info(run_info):
-  # Note: psutil is not installed in the TensorFlow OSS tree.
-  # It is installable via pip.
-  import psutil  # pylint: disable=g-import-not-at-top
-  vmem = psutil.virtual_memory()
-  run_info["machine_config"]["memory_total"] = vmem.total
-  run_info["machine_config"]["memory_available"] = vmem.available
+  try:
+    # Note: psutil is not installed in the TensorFlow OSS tree.
+    # It is installable via pip.
+    import psutil  # pylint: disable=g-import-not-at-top
+    vmem = psutil.virtual_memory()
+    run_info["machine_config"]["memory_total"] = vmem.total
+    run_info["machine_config"]["memory_available"] = vmem.available
+  except ImportError:
+    tf.logging.warn("'psutil' not imported. Memory info will not be logged.")


 def _parse_gpu_model(physical_device_desc):
official/utils/logs/logger_test.py
@@ -180,6 +180,32 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
     self.assertEqual(run_info["tensorflow_version"]["git_hash"],
                      tf.GIT_VERSION)

+  def test_collect_run_params(self):
+    run_info = {}
+    run_parameters = {
+        "batch_size": 32,
+        "synthetic_data": True,
+        "train_epochs": 100.00,
+        "dtype": "fp16",
+        "resnet_size": 50,
+        "random_tensor": tf.constant(2.0)
+    }
+    logger._collect_run_params(run_info, run_parameters)
+    self.assertEqual(len(run_info["run_parameters"]), 6)
+    self.assertEqual(run_info["run_parameters"][0],
+                     {"name": "batch_size", "long_value": 32})
+    self.assertEqual(run_info["run_parameters"][1],
+                     {"name": "dtype", "string_value": "fp16"})
+    self.assertEqual(
+        run_info["run_parameters"][2],
+        {"name": "random_tensor", "string_value":
+             "Tensor(\"Const:0\", shape=(), dtype=float32)"})
+    self.assertEqual(run_info["run_parameters"][3],
+                     {"name": "resnet_size", "long_value": 50})
+    self.assertEqual(run_info["run_parameters"][4],
+                     {"name": "synthetic_data", "bool_value": "True"})
+    self.assertEqual(run_info["run_parameters"][5],
+                     {"name": "train_epochs", "float_value": 100.00})
+
   def test_collect_tensorflow_environment_variables(self):
     os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
     os.environ["TF_OTHER"] = "2"
official/utils/testing/integration.py
@@ -19,12 +19,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
 import shutil
 import sys
 import tempfile

+from absl import flags
+
+from official.utils.flags import core as flags_core
+

 def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
   """Performs a minimal run of a model.

@@ -55,7 +58,8 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
   args.extend(["--max_train_steps", str(max_train)])

   try:
-    main(args)
+    flags_core.parse_flags(argv=args)
+    main(flags.FLAGS)
   finally:
     if os.path.exists(model_dir):
       shutil.rmtree(model_dir)
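
After this change, a model's smoke test would call run_synthetic roughly like this (my_model is hypothetical; the point is that main now receives flags.FLAGS rather than a raw argv list):

from official.utils.testing import integration
from official.my_model import my_model  # hypothetical model module


def test_synthetic_run():
  # run_synthetic builds an argv list, calls flags_core.parse_flags(argv=args),
  # and then invokes main(flags.FLAGS), cleaning up model_dir afterwards.
  integration.run_synthetic(
      main=my_model.main, tmp_root="/tmp/my_model_test",
      extra_flags=["--batch_size", "8"])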
official/utils/testing/pylint.rcfile
@@ -61,7 +61,7 @@ variable-rgx=^[a-z][a-z0-9_]*$
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis. It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
+ignored-modules=absl, absl.*, official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*

 [CLASSES]
official/wide_deep/wide_deep.py
@@ -17,17 +17,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import argparse
 import os
 import shutil
-import sys

+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.misc import model_helpers

 _CSV_COLUMNS = [
     'age', 'workclass', 'fnlwgt', 'education', 'education_num',
     'marital_status', 'occupation', 'relationship', 'race', 'gender',

@@ -47,6 +48,24 @@ _NUM_EXAMPLES = {
 LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}


+def define_wide_deep_flags():
+  """Add supervised learning flags, as well as wide-deep model type."""
+  flags_core.define_base()
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_enum(
+      name="model_type", short_name="mt", default="wide_deep",
+      enum_values=['wide', 'deep', 'wide_deep'],
+      help="Select model topology.")
+
+  flags_core.set_defaults(data_dir='/tmp/census_data',
+                          model_dir='/tmp/census_model',
+                          train_epochs=40,
+                          epochs_between_evals=2,
+                          batch_size=40)
+
+
 def build_model_columns():
   """Builds a set of wide and deep feature columns."""
   # Continuous columns

@@ -196,70 +215,60 @@ def export_model(model, model_type, export_dir):
   model.export_savedmodel(export_dir, example_input_fn)


-def main(argv):
-  parser = WideDeepArgParser()
-  flags = parser.parse_args(args=argv[1:])
+def run_wide_deep(flags_obj):
+  """Run Wide-Deep training and eval loop.
+
+  Args:
+    flags_obj: An object containing parsed flag values.
+  """

   # Clean up the model directory if present
-  shutil.rmtree(flags.model_dir, ignore_errors=True)
-  model = build_estimator(flags.model_dir, flags.model_type)
+  shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
+  model = build_estimator(flags_obj.model_dir, flags_obj.model_type)

-  train_file = os.path.join(flags.data_dir, 'adult.data')
-  test_file = os.path.join(flags.data_dir, 'adult.test')
+  train_file = os.path.join(flags_obj.data_dir, 'adult.data')
+  test_file = os.path.join(flags_obj.data_dir, 'adult.test')

   # Train and evaluate the model every `flags.epochs_between_evals` epochs.
   def train_input_fn():
     return input_fn(
-        train_file, flags.epochs_between_evals, True, flags.batch_size)
+        train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)

   def eval_input_fn():
-    return input_fn(test_file, 1, False, flags.batch_size)
+    return input_fn(test_file, 1, False, flags_obj.batch_size)

-  loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
+  loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
   train_hooks = hooks_helper.get_train_hooks(
-      flags.hooks, batch_size=flags.batch_size,
+      flags_obj.hooks, batch_size=flags_obj.batch_size,
       tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                       'loss': loss_prefix + 'head/weighted_loss/Sum'})

   # Train and evaluate the model every `flags.epochs_between_evals` epochs.
-  for n in range(flags.train_epochs // flags.epochs_between_evals):
+  for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
     model.train(input_fn=train_input_fn, hooks=train_hooks)
     results = model.evaluate(input_fn=eval_input_fn)

     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * flags.epochs_between_evals)
+    print('Results at epoch', (n + 1) * flags_obj.epochs_between_evals)
     print('-' * 60)

     for key in sorted(results):
       print('%s: %s' % (key, results[key]))

     if model_helpers.past_stop_threshold(
-        flags.stop_threshold, results['accuracy']):
+        flags_obj.stop_threshold, results['accuracy']):
       break

   # Export the model
-  if flags.export_dir is not None:
-    export_model(model, flags.model_type, flags.export_dir)
-
-
-class WideDeepArgParser(argparse.ArgumentParser):
-  """Argument parser for running the wide deep model."""
-
-  def __init__(self):
-    super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
-    self.add_argument(
-        '--model_type', '-mt', type=str, default='wide_deep',
-        choices=['wide', 'deep', 'wide_deep'],
-        help='[default %(default)s] Valid model types: wide, deep, wide_deep.',
-        metavar='<MT>')
-    self.set_defaults(
-        data_dir='/tmp/census_data',
-        model_dir='/tmp/census_model',
-        train_epochs=40,
-        epochs_between_evals=2,
-        batch_size=40)
+  if flags_obj.export_dir is not None:
+    export_model(model, flags_obj.model_type, flags_obj.export_dir)
+
+
+def main(_):
+  run_wide_deep(flags.FLAGS)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_wide_deep_flags()
+  absl_app.run(main)
official/wide_deep/wide_deep_test.py
@@ -48,6 +48,11 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')

 class BaseTest(tf.test.TestCase):
   """Tests for Wide Deep model."""

+  @classmethod
+  def setUpClass(cls):  # pylint: disable=invalid-name
+    super(BaseTest, cls).setUpClass()
+    wide_deep.define_wide_deep_flags()
+
   def setUp(self):
     # Create temporary CSV file
     self.temp_dir = self.get_temp_dir()
research/README.md
@@ -26,7 +26,7 @@ installation](https://www.tensorflow.org/install).
   for visual navigation.
 - [compression](compression): compressing and decompressing images using a
   pre-trained Residual GRU network.
-- [deeplab](deeplab): deep labelling for semantic image segmentation.
+- [deeplab](deeplab): deep labeling for semantic image segmentation.
 - [delf](delf): deep local features for image matching and retrieval.
 - [differential_privacy](differential_privacy): differential privacy for training
   data.

@@ -55,6 +55,7 @@ installation](https://www.tensorflow.org/install).
 - [pcl_rl](pcl_rl): code for several reinforcement learning algorithms,
   including Path Consistency Learning.
 - [ptn](ptn): perspective transformer nets for 3D object reconstruction.
+- [marco](marco): automating the evaluation of crystallization experiments.
 - [qa_kg](qa_kg): module networks for question answering on knowledge graphs.
 - [real_nvp](real_nvp): density estimation using real-valued non-volume
   preserving (real NVP) transformations.
research/differential_privacy/__init__.py
research/differential_privacy/multiple_teachers/__init__.py (new file, mode 100644)
research/differential_privacy/multiple_teachers/analysis.py
@@ -240,7 +240,7 @@ def main(unused_argv):
   counts_mat = np.zeros((n, 10)).astype(np.int32)
   for i in range(n):
     for j in range(num_teachers):
-      counts_mat[i, input_mat[j, i]] += 1
+      counts_mat[i, int(input_mat[j, i])] += 1
   n = counts_mat.shape[0]
   num_examples = min(n, FLAGS.max_examples)
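
The int(...) cast guards against float-valued entries in input_mat; NumPy refuses non-integer indices, as this small sketch shows (illustrative values):

import numpy as np

counts_mat = np.zeros((2, 10), dtype=np.int32)
label = np.float32(3.0)          # e.g. a teacher vote loaded as float
# counts_mat[0, label] += 1      # raises IndexError on modern NumPy
counts_mat[0, int(label)] += 1   # the cast restores valid integer indexing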
research/differential_privacy/pate/ICLR2018/plot_partition.py
@@ -186,7 +186,7 @@ def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
         ss=rdp_ss[order_idx],
         delta=-math.log(delta) / (order_opt[i] - 1))
     ss_std_opt[i] = ss_std[order_idx]

-    if i > 0 and (i + 1) % 10 == 0:
+    if i > 0 and (i + 1) % 1 == 0:
       print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
             'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
             'step2 = {:.3f}, ss = {:.3f}'.format(
research/differential_privacy/pate/README.md
@@ -15,7 +15,7 @@ dataset.
 The framework consists of _teachers_, the _student_ model, and the _aggregator_. The
 teachers are models trained on disjoint subsets of the training datasets. The student
-model has access to an insensitive (i.e., public) unlabelled dataset, which is labelled by
+model has access to an insensitive (e.g., public) unlabelled dataset, which is labelled by
 interacting with the ensemble of teachers via the _aggregator_. The aggregator tallies
 outputs of the teacher models, and either forwards a (noisy) aggregate to the student, or
 refuses to answer.

@@ -57,13 +57,13 @@ $ python smooth_sensitivity_test.py
 ## Files in this directory

-*   core.py --- RDP privacy accountant for several vote aggregators (GNMax,
+*   core.py — RDP privacy accountant for several vote aggregators (GNMax,
     Threshold, Laplace).
-*   smooth_sensitivity.py --- Smooth sensitivity analysis for GNMax and
+*   smooth_sensitivity.py — Smooth sensitivity analysis for GNMax and
     Threshold mechanisms.
-*   core_test.py and smooth_sensitivity_test.py --- Unit tests for the
+*   core_test.py and smooth_sensitivity_test.py — Unit tests for the
     files above.

 ## Contact information