Commit b4cb2454 (ModelZoo / ResNet50_tensorflow)
Authored Oct 21, 2016 by Xin Pan; committed by GitHub, Oct 21, 2016
Merge pull request #567 from panyx0718/master
Add differential privacy training.
Parents: 31559b69, 107e72cc
Showing 13 changed files with 2208 additions and 0 deletions
differential_privacy/README.md  +93 -0
differential_privacy/__init__.py  +0 -0
differential_privacy/dp_mnist/BUILD  +24 -0
differential_privacy/dp_mnist/dp_mnist.py  +506 -0
differential_privacy/dp_optimizer/BUILD  +54 -0
differential_privacy/dp_optimizer/accountant.py  +402 -0
differential_privacy/dp_optimizer/dp_optimizer.py  +241 -0
differential_privacy/dp_optimizer/dp_pca.py  +65 -0
differential_privacy/dp_optimizer/sanitizer.py  +123 -0
differential_privacy/dp_optimizer/utils.py  +312 -0
differential_privacy/per_example_gradients/BUILD  +22 -0
differential_privacy/per_example_gradients/__init__.py  +0 -0
differential_privacy/per_example_gradients/per_example_gradients.py  +366 -0
differential_privacy/README.md
0 → 100644
<font size=4><b>Deep Learning with Differential Privacy</b></font>
Authors:
Martín Abadi, Andy Chu, Ian Goodfellow, H. Brendan McMahan, Ilya Mironov, Kunal Talwar, Li Zhang
Open Sourced By: Xin Pan (xpan@google.com, github: panyx0718)
<Introduction>
Machine learning techniques based on neural networks are achieving remarkable
results in a wide variety of domains. Often, the training of models requires
large, representative datasets, which may be crowdsourced and contain sensitive
information. The models should not expose private information in these datasets.
Addressing this goal, we develop new algorithmic techniques for learning and a
refined analysis of privacy costs within the framework of differential privacy.
Our implementation and experiments demonstrate that we can train deep neural
networks with non-convex objectives, under a modest privacy budget, and at a
manageable cost in software complexity, training efficiency, and model quality.
paper: https://arxiv.org/abs/1607.00133
<b>Requirements:</b>

1. Tensorflow 0.10.0 (master branch)

   Note: r0.11 might experience some problems

2. Bazel 0.3.1

3. Download MNIST data

   TODO(xpan): Complete the link:
   [train](http://download.tensorflow.org/models/)
   [test](http://download.tensorflow.org/models/)

   Alternatively, download the tfrecord format MNIST from:
   https://github.com/panyx0718/models/tree/master/slim

<b>How to run:</b>
```shell
# Clone the code under differential_privacy.
# Create an empty WORKSPACE file.
# Download the data to the data/ directory.

ls -R
.:
data  differential_privacy  WORKSPACE

./data:
mnist_test.tfrecord  mnist_train.tfrecord

./differential_privacy:
dp_mnist  dp_optimizer  __init__.py  per_example_gradients  README.md

./differential_privacy/dp_mnist:
BUILD  dp_mnist.py

./differential_privacy/dp_optimizer:
accountant.py  BUILD  dp_optimizer.py  dp_pca.py  sanitizer.py  utils.py

./differential_privacy/per_example_gradients:
BUILD  __init__.py  per_example_gradients.py

# Build the code.
bazel build -c opt differential_privacy/...

# Run the MNIST differential privacy training code.
bazel-bin/differential_privacy/dp_mnist/dp_mnist \
  --training_data_path=data/mnist_train.tfrecord \
  --eval_data_path=data/mnist_test.tfrecord \
  --save_path=/tmp/mnist_dir

...
step: 1
step: 2
...
step: 9
spent privacy: eps 0.1250 delta 0.72709
spent privacy: eps 0.2500 delta 0.24708
spent privacy: eps 0.5000 delta 0.0029139
spent privacy: eps 1.0000 delta 6.494e-10
spent privacy: eps 2.0000 delta 8.2242e-24
spent privacy: eps 4.0000 delta 1.319e-51
spent privacy: eps 8.0000 delta 3.3927e-107
train_accuracy: 0.53
eval_accuracy: 0.53
...

ls /tmp/mnist_dir/
checkpoint  ckpt  ckpt.meta  results-0.json
```
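Each `spent privacy` line in the sample output pairs a target epsilon with the delta spent so far. The following is a minimal sketch, not part of the released code, of reading such lines and picking the smallest logged epsilon whose delta stays within a bound (the script's `--target_delta` defaults to 1e-5). The helper names `parse_spent_privacy` and `smallest_eps_within_delta` are hypothetical, and the regular expression assumes the log format shown above.

```python
import re

# Matches lines such as "spent privacy: eps 1.0000 delta 6.494e-10".
_SPENT_LINE = re.compile(r"spent privacy: eps (\S+) delta (\S+)")


def parse_spent_privacy(log_text):
    """Return the (eps, delta) pairs found in the log text, as floats."""
    return [(float(e), float(d)) for e, d in _SPENT_LINE.findall(log_text)]


def smallest_eps_within_delta(pairs, max_delta):
    """Return the smallest eps whose delta is <= max_delta, or None."""
    candidates = [eps for eps, delta in pairs if delta <= max_delta]
    return min(candidates) if candidates else None


# The sample lines are copied from the README output above.
sample_log = """\
spent privacy: eps 0.1250 delta 0.72709
spent privacy: eps 0.2500 delta 0.24708
spent privacy: eps 0.5000 delta 0.0029139
spent privacy: eps 1.0000 delta 6.494e-10
spent privacy: eps 2.0000 delta 8.2242e-24
spent privacy: eps 4.0000 delta 1.319e-51
spent privacy: eps 8.0000 delta 3.3927e-107
"""

pairs = parse_spent_privacy(sample_log)
best_eps = smallest_eps_within_delta(pairs, max_delta=1e-5)
print(best_eps)
```

For the sample output, the first epsilon whose delta falls below 1e-5 is the `eps 1.0000` line.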
differential_privacy/__init__.py
0 → 100644
differential_privacy/dp_mnist/BUILD
0 → 100644
package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package_group(
    name = "internal",
    packages = [
        "//differential_privacy/...",
    ],
)

py_binary(
    name = "dp_mnist",
    srcs = [
        "dp_mnist.py",
    ],
    deps = [
        "//differential_privacy/dp_optimizer",
        "//differential_privacy/dp_optimizer:dp_pca",
        "//differential_privacy/dp_optimizer:utils",
    ],
)
differential_privacy/dp_mnist/dp_mnist.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example differentially private trainer and evaluator for MNIST.
"""
from __future__ import division

import json
import os
import sys
import time

import numpy as np
import tensorflow as tf

from differential_privacy.dp_optimizer import accountant
from differential_privacy.dp_optimizer import dp_optimizer
from differential_privacy.dp_optimizer import dp_pca
from differential_privacy.dp_optimizer import sanitizer
from differential_privacy.dp_optimizer import utils
# parameters for the training
tf.flags.DEFINE_integer("batch_size", 600,
                        "The training batch size.")
tf.flags.DEFINE_integer("batches_per_lot", 1,
                        "Number of batches per lot.")
# Together, batch_size and batches_per_lot determine lot_size.
tf.flags.DEFINE_integer("num_training_steps", 50000,
                        "The number of training steps. "
                        "This counts number of lots.")

tf.flags.DEFINE_bool("randomize", True,
                     "If true, randomize the input data; otherwise use a fixed "
                     "seed and non-randomized input.")
tf.flags.DEFINE_bool("freeze_bottom_layers", False,
                     "If true, only train on the logit layer.")
tf.flags.DEFINE_bool("save_mistakes", False,
                     "If true, save the mistakes made during testing.")
tf.flags.DEFINE_float("lr", 0.05, "start learning rate")
tf.flags.DEFINE_float("end_lr", 0.05, "end learning rate")
tf.flags.DEFINE_float("lr_saturate_epochs", 0,
                      "learning rate saturate epochs; set to 0 for a constant "
                      "learning rate of --lr.")

# For searching parameters
tf.flags.DEFINE_integer("projection_dimensions", 60,
                        "PCA projection dimensions, or 0 for no projection.")
tf.flags.DEFINE_integer("num_hidden_layers", 1,
                        "Number of hidden layers in the network")
tf.flags.DEFINE_integer("hidden_layer_num_units", 1000,
                        "Number of units per hidden layer")
tf.flags.DEFINE_float("default_gradient_l2norm_bound", 4.0, "norm clipping")
tf.flags.DEFINE_integer("num_conv_layers", 0,
                        "Number of convolutional layers to use.")

tf.flags.DEFINE_string("training_data_path",
                       "/tmp/mnist/mnist_train.tfrecord",
                       "Location of the training data.")
tf.flags.DEFINE_string("eval_data_path",
                       "/tmp/mnist/mnist_test.tfrecord",
                       "Location of the eval data.")
tf.flags.DEFINE_integer("eval_steps", 10,
                        "Evaluate the model every eval_steps")

# Parameters for privacy spending. We allow linearly varying eps during
# training.
tf.flags.DEFINE_string("accountant_type", "Moments", "Moments, Amortized.")

# Flags that control privacy spending during training.
tf.flags.DEFINE_float("eps", 1.0,
                      "Start privacy spending for one epoch of training, "
                      "used if accountant_type is Amortized.")
tf.flags.DEFINE_float("end_eps", 1.0,
                      "End privacy spending for one epoch of training, "
                      "used if accountant_type is Amortized.")
tf.flags.DEFINE_float("eps_saturate_epochs", 0,
                      "Stop varying epsilon after eps_saturate_epochs. Set to "
                      "0 for constant eps of --eps. "
                      "Used if accountant_type is Amortized.")
tf.flags.DEFINE_float("delta", 1e-5,
                      "Privacy spending for training. Constant through "
                      "training, used if accountant_type is Amortized.")
tf.flags.DEFINE_float("sigma", 4.0,
                      "Noise sigma, used only if accountant_type is Moments")

# Flags that control privacy spending for the pca projection
# (only used if --projection_dimensions > 0).
tf.flags.DEFINE_float("pca_eps", 0.5,
                      "Privacy spending for PCA, used if accountant_type is "
                      "Amortized.")
tf.flags.DEFINE_float("pca_delta", 0.005,
                      "Privacy spending for PCA, used if accountant_type is "
                      "Amortized.")
tf.flags.DEFINE_float("pca_sigma", 7.0,
                      "Noise sigma for PCA, used if accountant_type is Moments")

tf.flags.DEFINE_string("target_eps", "0.125,0.25,0.5,1,2,4,8",
                       "Log the privacy loss for the target epsilon's. Only "
                       "used when accountant_type is Moments.")
tf.flags.DEFINE_float("target_delta", 1e-5,
                      "Maximum delta for --terminate_based_on_privacy.")
tf.flags.DEFINE_bool("terminate_based_on_privacy", False,
                     "Stop training if privacy spent exceeds "
                     "(max(--target_eps), --target_delta), even "
                     "if --num_training_steps have not yet been completed.")

tf.flags.DEFINE_string("save_path", "/tmp/mnist_dir",
                       "Directory for saving model outputs.")

FLAGS = tf.flags.FLAGS
NUM_TRAINING_IMAGES = 60000
NUM_TESTING_IMAGES = 10000
IMAGE_SIZE = 28
def MnistInput(mnist_data_file, batch_size, randomize):
  """Create operations to read the MNIST input file.

  Args:
    mnist_data_file: Path of a file containing the MNIST images to process.
    batch_size: size of the mini batches to generate.
    randomize: If true, randomize the dataset.

  Returns:
    images: A tensor with the formatted image data. shape [batch_size, 28*28]
    labels: A tensor with the labels for each image. shape [batch_size]
  """
  file_queue = tf.train.string_input_producer([mnist_data_file])
  reader = tf.TFRecordReader()
  _, value = reader.read(file_queue)
  example = tf.parse_single_example(
      value,
      features={"image/encoded": tf.FixedLenFeature(shape=(), dtype=tf.string),
                "image/class/label": tf.FixedLenFeature([1], tf.int64)})

  image = tf.cast(tf.image.decode_png(example["image/encoded"], channels=1),
                  tf.float32)
  image = tf.reshape(image, [IMAGE_SIZE * IMAGE_SIZE])
  image /= 255
  label = tf.cast(example["image/class/label"], dtype=tf.int32)
  label = tf.reshape(label, [])

  if randomize:
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size,
        capacity=(batch_size * 100),
        min_after_dequeue=(batch_size * 10))
  else:
    images, labels = tf.train.batch([image, label], batch_size=batch_size)

  return images, labels
def Eval(mnist_data_file, network_parameters, num_testing_images,
         randomize, load_path, save_mistakes=False):
  """Evaluate MNIST for a number of steps.

  Args:
    mnist_data_file: Path of a file containing the MNIST images to process.
    network_parameters: parameters for defining and training the network.
    num_testing_images: the number of images we will evaluate on.
    randomize: if true, randomize; otherwise, read the testing images
      sequentially.
    load_path: path where to load trained parameters from.
    save_mistakes: save the mistakes if True.

  Returns:
    The evaluation accuracy as a float.
  """
  batch_size = 100
  # Like for training, we need a session for executing the TensorFlow graph.
  with tf.Graph().as_default(), tf.Session() as sess:
    # Create the basic Mnist model.
    images, labels = MnistInput(mnist_data_file, batch_size, randomize)
    logits, _, _ = utils.BuildNetwork(images, network_parameters)
    softmax = tf.nn.softmax(logits)

    # Load the variables.
    ckpt_state = tf.train.get_checkpoint_state(load_path)
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      raise ValueError("No model checkpoint to eval at %s\n" % load_path)

    saver = tf.train.Saver()
    saver.restore(sess, ckpt_state.model_checkpoint_path)
    coord = tf.train.Coordinator()
    _ = tf.train.start_queue_runners(sess=sess, coord=coord)

    total_examples = 0
    correct_predictions = 0
    image_index = 0
    mistakes = []
    for _ in xrange((num_testing_images + batch_size - 1) // batch_size):
      predictions, label_values = sess.run([softmax, labels])

      # Count how many were predicted correctly.
      for prediction, label_value in zip(predictions, label_values):
        total_examples += 1
        if np.argmax(prediction) == label_value:
          correct_predictions += 1
        elif save_mistakes:
          mistakes.append({"index": image_index,
                           "label": label_value,
                           "pred": np.argmax(prediction)})
        image_index += 1

  return (correct_predictions / total_examples,
          mistakes if save_mistakes else None)
def Train(mnist_train_file, mnist_test_file, network_parameters, num_steps,
          save_path, eval_steps=0):
  """Train MNIST for a number of steps.

  Args:
    mnist_train_file: path of MNIST train data file.
    mnist_test_file: path of MNIST test data file.
    network_parameters: parameters for defining and training the network.
    num_steps: number of steps to run. Here steps = lots
    save_path: path where to save trained parameters.
    eval_steps: evaluate the model every eval_steps.

  Returns:
    the result after the final training step.

  Raises:
    ValueError: if the accountant_type is not supported.
  """
  batch_size = FLAGS.batch_size

  params = {"accountant_type": FLAGS.accountant_type,
            "task_id": 0,
            "batch_size": FLAGS.batch_size,
            "projection_dimensions": FLAGS.projection_dimensions,
            "default_gradient_l2norm_bound":
                network_parameters.default_gradient_l2norm_bound,
            "num_hidden_layers": FLAGS.num_hidden_layers,
            "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
            "num_examples": NUM_TRAINING_IMAGES,
            "learning_rate": FLAGS.lr,
            "end_learning_rate": FLAGS.end_lr,
            "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs}
  # Log different privacy parameters dependent on the accountant type.
  if FLAGS.accountant_type == "Amortized":
    params.update({"flag_eps": FLAGS.eps,
                   "flag_delta": FLAGS.delta,
                   "flag_pca_eps": FLAGS.pca_eps,
                   "flag_pca_delta": FLAGS.pca_delta,
                  })
  elif FLAGS.accountant_type == "Moments":
    params.update({"sigma": FLAGS.sigma,
                   "pca_sigma": FLAGS.pca_sigma,
                  })
  with tf.Graph().as_default(), tf.Session() as sess, tf.device('/cpu:0'):
    # Create the basic Mnist model.
    images, labels = MnistInput(mnist_train_file, batch_size, FLAGS.randomize)

    logits, projection, training_params = utils.BuildNetwork(
        images, network_parameters)

    cost = tf.nn.softmax_cross_entropy_with_logits(
        logits, tf.one_hot(labels, 10))

    # The actual cost is the average across the examples.
    cost = tf.reduce_sum(cost, [0]) / batch_size

    if FLAGS.accountant_type == "Amortized":
      priv_accountant = accountant.AmortizedAccountant(NUM_TRAINING_IMAGES)
      sigma = None
      pca_sigma = None
      with_privacy = FLAGS.eps > 0
    elif FLAGS.accountant_type == "Moments":
      priv_accountant = accountant.GaussianMomentsAccountant(
          NUM_TRAINING_IMAGES)
      sigma = FLAGS.sigma
      pca_sigma = FLAGS.pca_sigma
      with_privacy = FLAGS.sigma > 0
    else:
      # The flag is named accountant_type; FLAGS.accountant is not defined.
      raise ValueError("Undefined accountant type, needs to be "
                       "Amortized or Moments, but got %s"
                       % FLAGS.accountant_type)
    # Note: Here and below, we scale down the l2norm_bound by
    # batch_size. This is because per_example_gradients computes the
    # gradient of the minibatch loss with respect to each individual
    # example, and the minibatch loss (for our model) is the *average*
    # loss over examples in the minibatch. Hence, the scale of the
    # per-example gradients goes like 1 / batch_size.
    gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
        priv_accountant,
        [network_parameters.default_gradient_l2norm_bound / batch_size, True])

    for var in training_params:
      if "gradient_l2norm_bound" in training_params[var]:
        l2bound = training_params[var]["gradient_l2norm_bound"] / batch_size
        gaussian_sanitizer.set_option(var,
                                      sanitizer.ClipOption(l2bound, True))
    lr = tf.placeholder(tf.float32)
    eps = tf.placeholder(tf.float32)
    delta = tf.placeholder(tf.float32)

    init_ops = []
    if network_parameters.projection_type == "PCA":
      with tf.variable_scope("pca"):
        # Compute differentially private PCA.
        all_data, _ = MnistInput(mnist_train_file, NUM_TRAINING_IMAGES, False)
        pca_projection = dp_pca.ComputeDPPrincipalProjection(
            all_data, network_parameters.projection_dimensions,
            gaussian_sanitizer, [FLAGS.pca_eps, FLAGS.pca_delta], pca_sigma)
        assign_pca_proj = tf.assign(projection, pca_projection)
        init_ops.append(assign_pca_proj)

    # Add global_step
    global_step = tf.Variable(0, dtype=tf.int32, trainable=False,
                              name="global_step")

    if with_privacy:
      gd_op = dp_optimizer.DPGradientDescentOptimizer(
          lr,
          [eps, delta],
          gaussian_sanitizer,
          sigma=sigma,
          batches_per_lot=FLAGS.batches_per_lot).minimize(
              cost, global_step=global_step)
    else:
      gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)

    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    _ = tf.train.start_queue_runners(sess=sess, coord=coord)

    # We need to maintain the initialization sequence.
    for v in tf.trainable_variables():
      sess.run(tf.initialize_variables([v]))
    sess.run(tf.initialize_all_variables())
    sess.run(init_ops)
    results = []
    start_time = time.time()
    prev_time = start_time
    filename = "results-0.json"
    log_path = os.path.join(save_path, filename)

    target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
    if FLAGS.accountant_type == "Amortized":
      # Only matters if --terminate_based_on_privacy is true.
      target_eps = [max(target_eps)]
    max_target_eps = max(target_eps)

    lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
    lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
    for step in xrange(num_steps):
      epoch = step / lots_per_epoch
      curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
                               FLAGS.lr_saturate_epochs, epoch)
      curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
                                FLAGS.eps_saturate_epochs, epoch)
      for _ in xrange(FLAGS.batches_per_lot):
        _ = sess.run(
            [gd_op], feed_dict={lr: curr_lr, eps: curr_eps, delta: FLAGS.delta})
      sys.stderr.write("step: %d\n" % step)

      # See if we should stop training due to exceeded privacy budget:
      should_terminate = False
      terminate_spent_eps_delta = None
      if with_privacy and FLAGS.terminate_based_on_privacy:
        terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
            sess, target_eps=[max_target_eps])[0]
        # For the Moments accountant, we should always have
        # spent_eps == max_target_eps.
        if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta or
            terminate_spent_eps_delta.spent_eps > max_target_eps):
          should_terminate = True

      if (eval_steps > 0 and (step + 1) % eval_steps == 0) or should_terminate:
        if with_privacy:
          spent_eps_deltas = priv_accountant.get_privacy_spent(
              sess, target_eps=target_eps)
        else:
          spent_eps_deltas = [accountant.EpsDelta(0, 0)]
        for spent_eps, spent_delta in spent_eps_deltas:
          sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" % (
              spent_eps, spent_delta))

        saver.save(sess, save_path=save_path + "/ckpt")
        train_accuracy, _ = Eval(mnist_train_file, network_parameters,
                                 num_testing_images=NUM_TESTING_IMAGES,
                                 randomize=True, load_path=save_path)
        sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
        test_accuracy, mistakes = Eval(mnist_test_file, network_parameters,
                                       num_testing_images=NUM_TESTING_IMAGES,
                                       randomize=False, load_path=save_path,
                                       save_mistakes=FLAGS.save_mistakes)
        sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)

        curr_time = time.time()
        elapsed_time = curr_time - prev_time
        prev_time = curr_time

        results.append({"step": step + 1,  # Number of lots trained so far.
                        "elapsed_secs": elapsed_time,
                        "spent_eps_deltas": spent_eps_deltas,
                        "train_accuracy": train_accuracy,
                        "test_accuracy": test_accuracy,
                        "mistakes": mistakes})
        loginfo = {"elapsed_secs": curr_time - start_time,
                   "spent_eps_deltas": spent_eps_deltas,
                   "train_accuracy": train_accuracy,
                   "test_accuracy": test_accuracy,
                   "num_training_steps": step + 1,  # Steps so far.
                   "mistakes": mistakes,
                   "result_series": results}
        loginfo.update(params)
        if log_path:
          with tf.gfile.Open(log_path, "w") as f:
            json.dump(loginfo, f, indent=2)
            f.write("\n")
            f.close()

      if should_terminate:
        break
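Review note: the lot, epoch and clipping arithmetic in `Train` can be checked in isolation. The sketch below is an illustration only; the helper name `lot_arithmetic` is hypothetical, and it uses plain Python 3 division in place of the script's `from __future__ import division`. With the default flags (`--batch_size=600`, `--batches_per_lot=1`, `--num_training_steps=50000`, `--default_gradient_l2norm_bound=4.0`), one epoch is 100 lots and a full run covers 500 epochs.

```python
NUM_TRAINING_IMAGES = 60000  # Same constant as in dp_mnist.py.


def lot_arithmetic(batch_size, batches_per_lot, num_training_steps,
                   l2norm_bound):
    """Reproduce the lot/epoch bookkeeping used by Train (illustrative)."""
    lot_size = batches_per_lot * batch_size
    lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
    return {
        "lot_size": lot_size,
        "lots_per_epoch": lots_per_epoch,
        # One training step processes one lot, so steps / lots_per_epoch
        # gives the number of passes over the training set.
        "epochs": num_training_steps / lots_per_epoch,
        # The sanitizer is given the clipping bound divided by batch_size,
        # because the loss is averaged over the examples in a batch.
        "per_example_bound": l2norm_bound / batch_size,
    }


info = lot_arithmetic(batch_size=600, batches_per_lot=1,
                      num_training_steps=50000, l2norm_bound=4.0)
print(info)
```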
def main(_):
  network_parameters = utils.NetworkParameters()

  # If the ASCII proto isn't specified, then construct a config protobuf based
  # on 3 flags.
  network_parameters.input_size = IMAGE_SIZE ** 2
  network_parameters.default_gradient_l2norm_bound = (
      FLAGS.default_gradient_l2norm_bound)
  if FLAGS.projection_dimensions > 0 and FLAGS.num_conv_layers > 0:
    raise ValueError("Currently you can't do PCA and have convolutions "
                     "at the same time. Pick one")

  # could add support for PCA after convolutions.
  # Currently BuildNetwork can build the network with conv followed by
  # projection, but the PCA training works on data, rather than data run
  # through a few layers. Will need to init the convs before running the
  # PCA, and need to change the PCA subroutine to take a network and perhaps
  # allow for batched inputs, to handle larger datasets.
  if FLAGS.num_conv_layers > 0:
    conv = utils.ConvParameters()
    conv.name = "conv1"
    conv.in_channels = 1
    conv.out_channels = 128
    conv.num_outputs = 128 * 14 * 14
    network_parameters.conv_parameters.append(conv)
    # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
    # insize 28x28, bias, stddev 0.1, non-trainable.
  if FLAGS.num_conv_layers > 1:
    # ConvParameters is defined in utils (as for conv1 above), not on the
    # NetworkParameters instance.
    conv = utils.ConvParameters()
    conv.name = "conv2"
    conv.in_channels = 128
    conv.out_channels = 128
    conv.num_outputs = 128 * 7 * 7
    conv.in_size = 14
    # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
    # bias, stddev 0.1, non-trainable.
    network_parameters.conv_parameters.append(conv)

  if FLAGS.num_conv_layers > 2:
    raise ValueError("Currently --num_conv_layers must be 0,1 or 2. "
                     "Manually create a network_parameters proto for more.")

  if FLAGS.projection_dimensions > 0:
    network_parameters.projection_type = "PCA"
    network_parameters.projection_dimensions = FLAGS.projection_dimensions
  for i in xrange(FLAGS.num_hidden_layers):
    hidden = utils.LayerParameters()
    hidden.name = "hidden%d" % i
    hidden.num_units = FLAGS.hidden_layer_num_units
    hidden.relu = True
    hidden.with_bias = False
    hidden.trainable = not FLAGS.freeze_bottom_layers
    network_parameters.layer_parameters.append(hidden)

  logits = utils.LayerParameters()
  logits.name = "logits"
  logits.num_units = 10
  logits.relu = False
  logits.with_bias = False
  network_parameters.layer_parameters.append(logits)

  Train(FLAGS.training_data_path,
        FLAGS.eval_data_path,
        network_parameters,
        FLAGS.num_training_steps,
        FLAGS.save_path,
        eval_steps=FLAGS.eval_steps)


if __name__ == "__main__":
  tf.app.run()
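Review note: the hard-coded `num_outputs` values in `main` (128 * 14 * 14 for `conv1`, 128 * 7 * 7 for `conv2`) follow from the layer defaults named in its comments: a stride-1 convolution followed by 2x2 max pooling with stride 2. The sketch below recomputes them. The `SAME` padding is an assumption (the padding mode is set in `utils.py`, which is not shown here), and the helper names are hypothetical.

```python
def conv_pool_output_size(in_size, pool_stride=2):
    """Spatial size after a stride-1 conv and a max pool, assuming SAME padding.

    With SAME padding a stride-1 convolution keeps the spatial size, and the
    pooling layer divides it by its stride, rounding up.
    """
    return -(-in_size // pool_stride)  # Ceiling division.


def conv_num_outputs(in_size, out_channels):
    """Flattened output size of one conv + pool layer."""
    side = conv_pool_output_size(in_size)
    return out_channels * side * side


conv1_outputs = conv_num_outputs(in_size=28, out_channels=128)
conv2_outputs = conv_num_outputs(in_size=14, out_channels=128)
print(conv1_outputs, conv2_outputs)
```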
differential_privacy/dp_optimizer/BUILD
0 → 100644
package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package_group(
    name = "internal",
    packages = [
        "//differential_privacy/...",
    ],
)

py_library(
    name = "utils",
    srcs = [
        "utils.py",
    ],
    deps = [
    ],
)

py_library(
    name = "dp_pca",
    srcs = [
        "dp_pca.py",
    ],
    deps = [
    ],
)

py_library(
    name = "accountant",
    srcs = [
        "accountant.py",
    ],
    deps = [
        ":utils",
    ],
)

py_library(
    name = "dp_optimizer",
    srcs = [
        "dp_optimizer.py",
        "sanitizer.py",
    ],
    deps = [
        ":accountant",
        ":utils",
        "//differential_privacy/per_example_gradients",
    ],
)
differential_privacy/dp_optimizer/accountant.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Accountant class for keeping track of privacy spending.

A privacy accountant keeps track of privacy spendings. It has methods
accumulate_privacy_spending and get_privacy_spent. Here we only define
AmortizedAccountant which tracks the privacy spending in the amortized
way. It uses privacy amplification via sampling to compute the privacy
spending for each batch and strong composition (specialized for Gaussian
noise) to accumulate the privacy spending.
"""
from __future__ import division

import abc
import collections
import math
import sys

import numpy
import tensorflow as tf

from differential_privacy.dp_optimizer import utils

EpsDelta = collections.namedtuple("EpsDelta", ["spent_eps", "spent_delta"])

# TODO(liqzhang) To ensure the same API for AmortizedAccountant and
# MomentsAccountant, we pass the union of arguments to both, so we
# have unused_sigma for AmortizedAccountant and unused_eps_delta for
# MomentsAccountant. Consider to revise the API to avoid the unused
# arguments. It would be good to use @abc.abstractmethod, etc, to
# define the common interface as a base class.
class AmortizedAccountant(object):
  """Keep track of privacy spending in an amortized way.

  AmortizedAccountant accumulates the privacy spending by assuming
  all the examples are processed uniformly at random so the spending is
  amortized among all the examples. And we assume that we use Gaussian noise
  so the accumulation is on eps^2 and delta, using advanced composition.
  """

  def __init__(self, total_examples):
    """Initialization. Currently only support amortized tracking.

    Args:
      total_examples: total number of examples.
    """

    assert total_examples > 0
    self._total_examples = total_examples
    self._eps_squared_sum = tf.Variable(tf.zeros([1]), trainable=False,
                                        name="eps_squared_sum")
    self._delta_sum = tf.Variable(tf.zeros([1]), trainable=False,
                                  name="delta_sum")

  def accumulate_privacy_spending(self, eps_delta, unused_sigma,
                                  num_examples):
    """Accumulate the privacy spending.

    Currently only support approximate privacy. Here we assume we use Gaussian
    noise on randomly sampled batch so we get better composition: 1. the per
    batch privacy is computed using privacy amplification via sampling bound;
    2. the composition is done using the composition with Gaussian noise.
    TODO(liqzhang) Add a link to a document that describes the bounds used.

    Args:
      eps_delta: EpsDelta pair which can be tensors.
      unused_sigma: the noise sigma. Unused for this accountant.
      num_examples: the number of examples involved.
    Returns:
      a TensorFlow operation for updating the privacy spending.
    """

    eps, delta = eps_delta
    with tf.control_dependencies(
        [tf.Assert(tf.greater(delta, 0),
                   ["delta needs to be greater than 0"])]):
      amortize_ratio = (tf.cast(num_examples, tf.float32) * 1.0 /
                        self._total_examples)
      # Use privacy amplification via sampling bound.
      # See Lemma 2.2 in http://arxiv.org/pdf/1405.7085v2.pdf
      # TODO(liqzhang) Add a link to a document with formal statement
      # and proof.
      amortize_eps = tf.reshape(tf.log(1.0 + amortize_ratio * (
          tf.exp(eps) - 1.0)), [1])
      amortize_delta = tf.reshape(amortize_ratio * delta, [1])
      return tf.group(*[tf.assign_add(self._eps_squared_sum,
                                      tf.square(amortize_eps)),
                        tf.assign_add(self._delta_sum, amortize_delta)])

  def get_privacy_spent(self, sess, target_eps=None):
    """Report the spending so far.

    Args:
      sess: the session to run the tensor.
      target_eps: the target epsilon. Unused.
    Returns:
      the list containing a single EpsDelta, with values as Python floats (as
      opposed to numpy.float64). This is to be consistent with
      MomentsAccountant which can return a list of (eps, delta) pair.
    """

    # pylint: disable=unused-argument
    unused_target_eps = target_eps
    eps_squared_sum, delta_sum = sess.run([self._eps_squared_sum,
                                           self._delta_sum])
    return [EpsDelta(math.sqrt(eps_squared_sum), float(delta_sum))]

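Review note: the bookkeeping done by `AmortizedAccountant` can be followed without a TensorFlow session. The sketch below mirrors the same three steps in plain Python 3: amplification by the sampling ratio, accumulation of eps squared and delta, and a square root when reporting. The class name `AmortizedAccountantSketch` and its method names are hypothetical, and this is an illustration of the arithmetic only, not a replacement for the class above.

```python
import math


class AmortizedAccountantSketch(object):
    """Plain-Python mirror of the arithmetic in AmortizedAccountant."""

    def __init__(self, total_examples):
        assert total_examples > 0
        self._total_examples = total_examples
        self._eps_squared_sum = 0.0
        self._delta_sum = 0.0

    def accumulate(self, eps, delta, num_examples):
        """Account for one step that touched num_examples sampled examples."""
        assert delta > 0, "delta needs to be greater than 0"
        ratio = num_examples / self._total_examples
        # Privacy amplification via sampling: the per-step epsilon shrinks
        # to log(1 + ratio * (exp(eps) - 1)) and delta to ratio * delta.
        amortized_eps = math.log(1.0 + ratio * (math.exp(eps) - 1.0))
        self._eps_squared_sum += amortized_eps ** 2
        self._delta_sum += ratio * delta

    def spent(self):
        """Return (eps, delta) spent so far."""
        return math.sqrt(self._eps_squared_sum), self._delta_sum


# 100 steps of 600 examples out of 60000, each with eps=1.0 and delta=1e-5.
acct = AmortizedAccountantSketch(total_examples=60000)
for _ in range(100):
    acct.accumulate(eps=1.0, delta=1e-5, num_examples=600)
spent_eps, spent_delta = acct.spent()
print(spent_eps, spent_delta)
```

Because the sum is over squared epsilons, 100 identical steps cost 10 times the amortized per-step epsilon rather than 100 times, while delta adds up linearly.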
class
MomentsAccountant
(
object
):
"""Privacy accountant which keeps track of moments of privacy loss.
Note: The constructor of this class creates tf.Variables that must
be initialized with tf.initialize_all_variables() or similar calls.
MomentsAccountant accumulates the high moments of the privacy loss. It
requires a method for computing differenital moments of the noise (See
below for the definition). So every specific accountant should subclass
this class by implementing _differential_moments method.
Denote by X_i the random variable of privacy loss at the i-th step.
Consider two databases D, D' which differ by one item. X_i takes value
log Pr[M(D')==x]/Pr[M(D)==x] with probability Pr[M(D)==x].
In MomentsAccountant, we keep track of y_i(L) = log E[exp(L X_i)] for some
large enough L. To compute the final privacy spending, we apply Chernoff
bound (assuming the random noise added at each step is independent) to
bound the total privacy loss Z = sum X_i as follows:
Pr[Z > e] = Pr[exp(L Z) > exp(L e)]
< E[exp(L Z)] / exp(L e)
= Prod_i E[exp(L X_i)] / exp(L e)
= exp(sum_i log E[exp(L X_i)]) / exp(L e)
= exp(sum_i y_i(L) - L e)
Hence the mechanism is (e, d)-differentially private for
d = exp(sum_i y_i(L) - L e).
We require d < 1, i.e. e > sum_i y_i(L) / L. We maintain y_i(L) for several
L to compute the best d for any give e (normally should be the lowest L
such that 2 * sum_i y_i(L) / L < e.
We further assume that at each step, the mechanism operates on a random
sample with sampling probability q = batch_size / total_examples. Then
E[exp(L X)] = E[(Pr[M(D)==x / Pr[M(D')==x])^L]
By distinguishign two cases of wether D < D' or D' < D, we have
that
E[exp(L X)] <= max (I1, I2)
where
I1 = (1-q) E ((1-q) + q P(X+1) / P(X))^L + q E ((1-q) + q P(X) / P(X-1))^L
I2 = E (P(X) / ((1-q) + q P(X+1)))^L
In order to compute I1 and I2, one can consider to
1. use an asymptotic bound, which recovers the advance composition theorem;
2. use the closed formula (like GaussianMomentsAccountant);
3. use numerical integration or random sample estimation.
Dependent on the distribution, we can often obtain a tigher estimation on
the moments and hence a more accurate estimation of the privacy loss than
obtained using generic composition theorems.
"""
  __metaclass__ = abc.ABCMeta

  def __init__(self, total_examples, moment_orders=32):
    """Initialize a MomentsAccountant.

    Args:
      total_examples: total number of examples.
      moment_orders: the order of moments to keep.
    """

    assert total_examples > 0
    self._total_examples = total_examples
    self._moment_orders = (moment_orders
                           if isinstance(moment_orders, (list, tuple))
                           else range(1, moment_orders + 1))
    self._max_moment_order = max(self._moment_orders)
    assert self._max_moment_order < 100, "The moment order is too large."
    self._log_moments = [tf.Variable(numpy.float64(0.0),
                                     trainable=False,
                                     name=("log_moments-%d" % moment_order))
                         for moment_order in self._moment_orders]

  @abc.abstractmethod
  def _compute_log_moment(self, sigma, q, moment_order):
    """Compute high moment of privacy loss.

    Args:
      sigma: the noise sigma, in the multiples of the sensitivity.
      q: the sampling ratio.
      moment_order: the order of moment.
    Returns:
      log E[exp(moment_order * X)]
    """
    pass

  def accumulate_privacy_spending(self, unused_eps_delta,
                                  sigma, num_examples):
    """Accumulate privacy spending.

    In particular, accounts for privacy spending when we assume there
    are num_examples, and we are releasing the vector
    (sum_{i=1}^{num_examples} x_i) + Normal(0, stddev=l2norm_bound*sigma)
    where l2norm_bound is the maximum l2_norm of each example x_i, and
    the num_examples have been randomly selected out of a pool of
    self.total_examples.

    Args:
      unused_eps_delta: EpsDelta pair which can be tensors. Unused
        in this accountant.
      sigma: the noise sigma, in the multiples of the sensitivity (that is,
        if the l2norm sensitivity is k, then the caller must have added
        Gaussian noise with stddev=k*sigma to the result of the query).
      num_examples: the number of examples involved.
    Returns:
      a TensorFlow operation for updating the privacy spending.
    """
    q = tf.cast(num_examples, tf.float64) * 1.0 / self._total_examples

    moments_accum_ops = []
    for i in range(len(self._log_moments)):
      moment = self._compute_log_moment(sigma, q, self._moment_orders[i])
      moments_accum_ops.append(tf.assign_add(self._log_moments[i], moment))
    return tf.group(*moments_accum_ops)

  def _compute_delta(self, log_moments, eps):
    """Compute delta for given log_moments and eps.

    Args:
      log_moments: the log moments of privacy loss, in the form of pairs
        of (moment_order, log_moment)
      eps: the target epsilon.
    Returns:
      delta
    """
    min_delta = 1.0
    for moment_order, log_moment in log_moments:
      if math.isinf(log_moment) or math.isnan(log_moment):
        sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order)
        continue
      if log_moment < moment_order * eps:
        min_delta = min(min_delta,
                        math.exp(log_moment - moment_order * eps))
    return min_delta

  def _compute_eps(self, log_moments, delta):
    min_eps = float("inf")
    for moment_order, log_moment in log_moments:
      if math.isinf(log_moment) or math.isnan(log_moment):
        sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order)
        continue
      min_eps = min(min_eps, (log_moment - math.log(delta)) / moment_order)
    return min_eps

  def get_privacy_spent(self, sess, target_eps=None, target_deltas=None):
    """Compute privacy spending in (e, d)-DP form for a single or list of eps.

    Args:
      sess: the session to run the tensor.
      target_eps: a list of target epsilons for which we would like to
        compute the corresponding delta value.
      target_deltas: a list of target deltas for which we would like to
        compute the corresponding eps value. Caller must specify
        exactly one of target_eps or target_deltas.
    Returns:
      A list of EpsDelta pairs.
    """
    assert (target_eps is None) ^ (target_deltas is None)
    eps_deltas = []
    log_moments = sess.run(self._log_moments)
    # Materialize the pairs as a list: they are iterated once per target, and
    # a bare zip() is a one-shot iterator under Python 3.
    log_moments_with_order = list(zip(self._moment_orders, log_moments))
    if target_eps is not None:
      for eps in target_eps:
        eps_deltas.append(
            EpsDelta(eps, self._compute_delta(log_moments_with_order, eps)))
    else:
      assert target_deltas
      for delta in target_deltas:
        eps_deltas.append(
            EpsDelta(self._compute_eps(log_moments_with_order, delta), delta))
    return eps_deltas
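
The two conversions behind `get_privacy_spent` can be tried outside TensorFlow. The following is a minimal sketch in plain Python, not the accountant itself: the names `delta_for_eps` and `eps_for_delta` and the sample log moments are assumptions made up for illustration.

```python
import math


def delta_for_eps(log_moments, eps):
  """Smallest delta = exp(y(L) - L * eps) over the tracked orders L."""
  best = 1.0
  for order, log_moment in log_moments:
    if log_moment < order * eps:
      best = min(best, math.exp(log_moment - order * eps))
  return best


def eps_for_delta(log_moments, delta):
  """Smallest eps = (y(L) - log(delta)) / L over the tracked orders L."""
  return min((log_moment - math.log(delta)) / order
             for order, log_moment in log_moments)


# Hypothetical accumulated log moments, given as (order, log_moment) pairs.
moments = [(1, 0.01), (2, 0.04), (4, 0.16), (8, 0.64)]
delta = delta_for_eps(moments, eps=2.0)
eps = eps_for_delta(moments, delta=1e-5)
print(delta, eps)
```

With these sample values the highest tracked order gives the best bound in both directions, which is why the accountant keeps several orders rather than one.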


class GaussianMomentsAccountant(MomentsAccountant):
  """MomentsAccountant which assumes Gaussian noise.

  GaussianMomentsAccountant assumes the noise added is centered Gaussian
  noise N(0, sigma^2 I). In this case, we can compute the differential moments
  accurately using a formula.

  For the asymptotic bound, for Gaussian noise with variance sigma^2, we can
  show that for L < sigma^2, q L < sigma,
    log E[exp(L X)] = O(q^2 L^2 / sigma^2).
  Using this we derive that for training T epochs, with batch ratio q,
  the Gaussian mechanism with variance sigma^2 (with q < 1/sigma) is (e, d)
  private for d = exp((T/q) q^2 L^2 / sigma^2 - L e). Setting L = sigma^2,
  Tq = e/2, the mechanism is (e, exp(-e sigma^2/2))-DP. Equivalently, the
  mechanism is (e, d)-DP if sigma = sqrt{2 log(1/d)/e}, q < 1/sigma,
  and T < e/(2q). This bound is better than the bound obtained using general
  composition theorems, by an Omega(sqrt{log k}) factor on epsilon, if we run
  k steps. Since we use a direct estimate, the obtained privacy bound has a
  tight constant.

  For GaussianMomentsAccountant, it suffices to compute I1, as I1 >= I2
  (TODO(liqzhang): make sure this is true.), which reduces to computing
  E(P(x+s)/P(x+s-1) - 1)^i for s = 0 and 1.
  """

  def __init__(self, total_examples, moment_orders=32):
    """Initialization.

    Args:
      total_examples: total number of examples.
      moment_orders: the order of moments to keep.
    """
    # Name the class explicitly: super(self.__class__, self) recurses
    # forever if this class is ever subclassed.
    super(GaussianMomentsAccountant, self).__init__(total_examples,
                                                    moment_orders)
    self._binomial_table = utils.GenerateBinomialTable(self._max_moment_order)

  def _differential_moments(self, sigma, s, t):
    """Compute 0 to t-th differential moments for Gaussian variable.

        E[(P(x+s)/P(x+s-1)-1)^t]
      = sum_{i=0}^t (t choose i) (-1)^{t-i} E[(P(x+s)/P(x+s-1))^i]
      = sum_{i=0}^t (t choose i) (-1)^{t-i} E[exp(-i*(2*x+2*s-1)/(2*sigma^2))]
      = sum_{i=0}^t (t choose i) (-1)^{t-i} exp(i(i+1-2*s)/(2 sigma^2))
    Args:
      sigma: the noise sigma, in the multiples of the sensitivity.
      s: the shift.
      t: 0 to t-th moment.
    Returns:
      0 to t-th moment as a tensor of shape [t+1].
    """
    assert t <= self._max_moment_order, ("The order of %d is out "
                                         "of the upper bound %d."
                                         % (t, self._max_moment_order))
    binomial = tf.slice(self._binomial_table, [0, 0],
                        [t + 1, t + 1])
    signs = numpy.zeros((t + 1, t + 1), dtype=numpy.float64)
    for i in range(t + 1):
      for j in range(t + 1):
        signs[i, j] = 1.0 - 2 * ((i - j) % 2)
    exponents = tf.constant([j * (j + 1.0 - 2.0 * s) / (2.0 * sigma * sigma)
                             for j in range(t + 1)], dtype=tf.float64)
    # x[i, j] = binomial[i, j] * signs[i, j] = (i choose j) * (-1)^{i-j}
    x = tf.mul(binomial, signs)
    # y[i, j] = x[i, j] * exp(exponents[j])
    #         = (i choose j) * (-1)^{i-j} * exp(j(j+1-2s)/(2 sigma^2))
    # Note: this computation is done by broadcasting pointwise multiplication
    # between [t+1, t+1] tensor and [t+1] tensor.
    y = tf.mul(x, tf.exp(exponents))
    # z[i] = sum_j y[i, j]
    #      = sum_j (i choose j) * (-1)^{i-j} * exp(j(j+1-2s)/(2 sigma^2))
    z = tf.reduce_sum(y, 1)
    return z

  def _compute_log_moment(self, sigma, q, moment_order):
    """Compute high moment of privacy loss.

    Args:
      sigma: the noise sigma, in the multiples of the sensitivity.
      q: the sampling ratio.
      moment_order: the order of moment.
    Returns:
      log E[exp(moment_order * X)]
    """
    assert moment_order <= self._max_moment_order, ("The order of %d is out "
                                                    "of the upper bound %d."
                                                    % (moment_order,
                                                       self._max_moment_order))
    binomial_table = tf.slice(self._binomial_table, [moment_order, 0],
                              [1, moment_order + 1])
    # qs = [1 q q^2 ... q^L] = exp([0 1 2 ... L] * log(q))
    qs = tf.exp(tf.constant([i * 1.0 for i in range(moment_order + 1)],
                            dtype=tf.float64)
                * tf.cast(tf.log(q), dtype=tf.float64))
    moments0 = self._differential_moments(sigma, 0.0, moment_order)
    term0 = tf.reduce_sum(binomial_table * qs * moments0)
    moments1 = self._differential_moments(sigma, 1.0, moment_order)
    term1 = tf.reduce_sum(binomial_table * qs * moments1)
    return tf.squeeze(tf.log(tf.cast(q * term0 + (1.0 - q) * term1,
                                     tf.float64)))
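
The closed formula used by `GaussianMomentsAccountant` can be checked numerically without a TensorFlow graph. The sketch below restates it in plain Python under the same definitions as the docstrings above; the function names `differential_moment` and `gaussian_log_moment` are hypothetical, and `math.comb` needs Python 3.8 or later.

```python
import math


def differential_moment(sigma, s, t):
  """E[(P(x+s)/P(x+s-1) - 1)^t] for Gaussian noise with stddev sigma."""
  return sum(math.comb(t, i) * (-1) ** (t - i)
             * math.exp(i * (i + 1.0 - 2.0 * s) / (2.0 * sigma * sigma))
             for i in range(t + 1))


def gaussian_log_moment(sigma, q, order):
  """log(q * term0 + (1 - q) * term1), mirroring _compute_log_moment."""
  def term(s):
    return sum(math.comb(order, i) * q ** i * differential_moment(sigma, s, i)
               for i in range(order + 1))
  return math.log(q * term(0.0) + (1.0 - q) * term(1.0))


# Per-step log moments for a few orders, with sigma = 4 and q = 0.01.
for order in (1, 2, 4, 8):
  print(order, gaussian_log_moment(4.0, 0.01, order))
```

Two sanity checks follow from the formula: with q = 0 no example is ever sampled and the log moment is 0, and with q = 1 the sum collapses to L(L+1)/(2 sigma^2). The alternating signs can lose floating-point precision at high orders, so this sketch is only meant for small ones.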


class DummyAccountant(object):
  """An accountant that does no accounting."""

  def accumulate_privacy_spending(self, *unused_args):
    return tf.no_op()

  def get_privacy_spent(self, unused_sess, **unused_kwargs):
    return [EpsDelta(numpy.inf, 1.0)]
differential_privacy/dp_optimizer/dp_optimizer.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Differentially private optimizers.
"""
from __future__ import division

import tensorflow as tf

from differential_privacy.dp_optimizer import utils
from differential_privacy.per_example_gradients import per_example_gradients


class DPGradientDescentOptimizer(tf.train.GradientDescentOptimizer):
  """Differentially private gradient descent optimizer.
  """

  def __init__(self, learning_rate, eps_delta, sanitizer,
               sigma=None, use_locking=False, name="DPGradientDescent",
               batches_per_lot=1):
    """Construct a differentially private gradient descent optimizer.

    The optimizer uses fixed privacy budget for each batch of training.

    Args:
      learning_rate: for GradientDescentOptimizer.
      eps_delta: EpsDelta pair for each epoch.
      sanitizer: for sanitizing the gradient.
      sigma: noise sigma. If None, use eps_delta pair to compute sigma;
        otherwise use supplied sigma directly.
      use_locking: use locking.
      name: name for the object.
      batches_per_lot: Number of batches in a lot.
    """

    super(DPGradientDescentOptimizer, self).__init__(learning_rate,
                                                     use_locking, name)

    # Also, if needed, define the gradient accumulators
    self._batches_per_lot = batches_per_lot
    self._grad_accum_dict = {}
    if batches_per_lot > 1:
      self._batch_count = tf.Variable(1, dtype=tf.int32, trainable=False,
                                      name="batch_count")
      var_list = tf.trainable_variables()
      with tf.variable_scope("grad_acc_for"):
        for var in var_list:
          v_grad_accum = tf.Variable(tf.zeros_like(var),
                                     trainable=False,
                                     name=utils.GetTensorOpName(var))
          self._grad_accum_dict[var.name] = v_grad_accum

    self._eps_delta = eps_delta
    self._sanitizer = sanitizer
    self._sigma = sigma

  def compute_sanitized_gradients(self, loss, var_list=None,
                                  add_noise=True):
    """Compute the sanitized gradients.

    Args:
      loss: the loss tensor.
      var_list: the optional variables.
      add_noise: if true, then add noise. Always clip.
    Returns:
      a list of sanitized gradients, one per variable in var_list. When
      add_noise is true, the sanitizer attaches the privacy spending
      accumulation op to each gradient as a control dependency.
    Raises:
      TypeError: if var_list contains non-variable.
    """

    self._assert_valid_dtypes([loss])

    xs = [tf.convert_to_tensor(x) for x in var_list]
    px_grads = per_example_gradients.PerExampleGradients(loss, xs)
    sanitized_grads = []
    for px_grad, v in zip(px_grads, var_list):
      tensor_name = utils.GetTensorOpName(v)
      sanitized_grad = self._sanitizer.sanitize(
          px_grad, self._eps_delta, sigma=self._sigma,
          tensor_name=tensor_name, add_noise=add_noise,
          num_examples=self._batches_per_lot * tf.slice(
              tf.shape(px_grad), [0], [1]))
      sanitized_grads.append(sanitized_grad)

    return sanitized_grads

  def minimize(self, loss, global_step=None, var_list=None,
               name=None):
    """Minimize using sanitized gradients.

    This gets a var_list which is the list of trainable variables.
    For each var in var_list, we defined a grad_accumulator variable
    during init. When batches_per_lot > 1, we accumulate the gradient
    update in those. At the end of each lot, we apply the update back to
    the variable. This has the effect that for each lot we compute
    gradients at the point at the beginning of the lot, and then apply one
    update at the end of the lot. In other words, semantically, we are doing
    SGD with one lot being the equivalent of one usual batch of size
    batch_size * batches_per_lot.
    This allows us to simulate larger batches than our memory size would permit.

    The lr and the num_steps are in the lot world.

    Args:
      loss: the loss tensor.
      global_step: the optional global step.
      var_list: the optional variables.
      name: the optional name.
    Returns:
      the operation that runs one step of DP gradient descent.
    """

    # First validate the var_list

    if var_list is None:
      var_list = tf.trainable_variables()
    for var in var_list:
      if not isinstance(var, tf.Variable):
        raise TypeError("Argument is not a tf.Variable: %s" % var)

    # Modification: apply gradient once every batches_per_lot many steps.
    # This may lead to smaller error

    if self._batches_per_lot == 1:
      sanitized_grads = self.compute_sanitized_gradients(
          loss, var_list=var_list)

      # Use a list so the pairs can be iterated twice under Python 3.
      grads_and_vars = list(zip(sanitized_grads, var_list))
      self._assert_valid_dtypes([v for g, v in grads_and_vars if g is not None])

      apply_grads = self.apply_gradients(grads_and_vars,
                                         global_step=global_step, name=name)
      return apply_grads

    # Condition for deciding whether to accumulate the gradient
    # or actually apply it.
    # We use a private self._batch_count to keep track of the number of
    # batches. global_step will count the number of lots processed.

    update_cond = tf.equal(tf.constant(0),
                           tf.mod(self._batch_count,
                                  tf.constant(self._batches_per_lot)))

    # Things to do for batches other than last of the lot.
    # Add non-noisy clipped grads to shadow variables.

    def non_last_in_lot_op(loss, var_list):
      """Ops to do for a typical batch.

      For a batch that is not the last one in the lot, we simply compute the
      sanitized gradients and apply them to the grad_acc variables.

      Args:
        loss: loss function tensor
        var_list: list of variables
      Returns:
        A tensorflow op to do the updates to the gradient accumulators
      """
      sanitized_grads = self.compute_sanitized_gradients(
          loss, var_list=var_list, add_noise=False)

      update_ops_list = []
      for var, grad in zip(var_list, sanitized_grads):
        grad_acc_v = self._grad_accum_dict[var.name]
        update_ops_list.append(grad_acc_v.assign_add(grad))
      update_ops_list.append(self._batch_count.assign_add(1))
      return tf.group(*update_ops_list)

    # Things to do for last batch of a lot.
    # Add noisy clipped grads to accumulator.
    # Apply accumulated grads to vars.

    def last_in_lot_op(loss, var_list, global_step):
      """Ops to do for last batch in a lot.

      For the last batch in the lot, we first add the sanitized gradients to
      the gradient acc variables, and then apply these
      values over to the original variables (via an apply gradient)

      Args:
        loss: loss function tensor
        var_list: list of variables
        global_step: optional global step to be passed to apply_gradients
      Returns:
        A tensorflow op to push updates from shadow vars to real vars.
      """

      # We add noise in the last batch of the lot. This is why we need this
      # code snippet that looks almost identical to the non_last_op case here.
      sanitized_grads = self.compute_sanitized_gradients(
          loss, var_list=var_list, add_noise=True)

      normalized_grads = []
      for var, grad in zip(var_list, sanitized_grads):
        grad_acc_v = self._grad_accum_dict[var.name]
        # To handle the lr difference per lot vs per batch, we divide the
        # update by number of batches per lot.
        normalized_grad = tf.div(grad_acc_v.assign_add(grad),
                                 tf.to_float(self._batches_per_lot))

        normalized_grads.append(normalized_grad)

      with tf.control_dependencies(normalized_grads):
        # Use a list so the pairs can be iterated twice under Python 3.
        grads_and_vars = list(zip(normalized_grads, var_list))
        self._assert_valid_dtypes(
            [v for g, v in grads_and_vars if g is not None])
        apply_san_grads = self.apply_gradients(grads_and_vars,
                                               global_step=global_step,
                                               name="apply_grads")

      # Now reset the accumulators to zero
      resets_list = []
      with tf.control_dependencies([apply_san_grads]):
        for _, acc in self._grad_accum_dict.items():
          reset = tf.assign(acc, tf.zeros_like(acc))
          resets_list.append(reset)
      resets_list.append(self._batch_count.assign_add(1))

      last_step_update = tf.group(*([apply_san_grads] + resets_list))
      return last_step_update
    # pylint: disable=g-long-lambda
    update_op = tf.cond(update_cond,
                        lambda: last_in_lot_op(
                            loss, var_list,
                            global_step),
                        lambda: non_last_in_lot_op(
                            loss, var_list))
    return tf.group(update_op)
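
The lot logic in `minimize` can be followed on a single scalar parameter. This is a rough sketch in plain Python rather than the optimizer's graph code: `run_lot` is a hypothetical helper, the gradient sums are assumed to be already clipped, and the noise value is passed in explicitly so the arithmetic stays visible.

```python
def run_lot(batch_grad_sums, noise, batches_per_lot):
  """Simulate one lot for a single scalar parameter.

  batch_grad_sums: clipped gradient sums, one per batch in the lot.
  noise: the noise value added once, together with the last batch.
  """
  assert len(batch_grad_sums) == batches_per_lot
  accumulator = 0.0
  for grad_sum in batch_grad_sums[:-1]:
    accumulator += grad_sum  # non-last batches: clipped, no noise
  accumulator += batch_grad_sums[-1] + noise  # last batch: noise added once
  # Divide by the number of batches so the learning rate is per lot.
  return accumulator / batches_per_lot


update = run_lot([1.0, 2.0, 3.0], noise=0.5, batches_per_lot=3)
print(update)
```

Noise enters once per lot rather than once per batch, which matches the comment in `last_in_lot_op` about adding noise only with the last batch.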
differential_privacy/dp_optimizer/dp_pca.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Differentially private principal component analysis (PCA).
"""

import tensorflow as tf

from differential_privacy.dp_optimizer import sanitizer as san


def ComputeDPPrincipalProjection(data, projection_dims,
                                 sanitizer, eps_delta, sigma):
  """Compute differentially private projection.

  Args:
    data: the input data, each row is a data vector.
    projection_dims: the projection dimension.
    sanitizer: the sanitizer used for achieving privacy.
    eps_delta: (eps, delta) pair.
    sigma: if not None, use noise sigma; otherwise compute it using
      eps_delta pair.
  Returns:
    A projection matrix with projection_dims columns.
  """

  eps, delta = eps_delta
  # Normalize each row.
  normalized_data = tf.nn.l2_normalize(data, 1)
  covar = tf.matmul(tf.transpose(normalized_data), normalized_data)
  saved_shape = tf.shape(covar)
  num_examples = tf.slice(tf.shape(data), [0], [1])
  if eps > 0:
    # Since the data is already normalized, there is no need to clip
    # the covariance matrix.

    assert delta > 0
    saned_covar = sanitizer.sanitize(
        tf.reshape(covar, [1, -1]), eps_delta, sigma=sigma,
        option=san.ClipOption(1.0, False), num_examples=num_examples)
    saned_covar = tf.reshape(saned_covar, saved_shape)
    # Symmetrize saned_covar. This also reduces the noise variance.
    saned_covar = 0.5 * (saned_covar + tf.transpose(saned_covar))
  else:
    saned_covar = covar

  # Compute the eigen decomposition of the covariance matrix, and
  # return the top projection_dims eigen vectors, represented as columns of
  # the projection matrix.
  eigvals, eigvecs = tf.self_adjoint_eig(saned_covar)
  _, topk_indices = tf.nn.top_k(eigvals, projection_dims)
  topk_indices = tf.reshape(topk_indices, [projection_dims])
  # Gather and return the corresponding eigenvectors.
  return tf.transpose(tf.gather(tf.transpose(eigvecs), topk_indices))
differential_privacy/dp_optimizer/sanitizer.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines Sanitizer class for sanitizing tensors.
A sanitizer first limits the sensitivity of a tensor and then adds noise
to the tensor. The parameters are determined by the privacy_spending and the
other parameters. It also uses an accountant to keep track of the privacy
spending.
"""
from __future__ import division

import collections

import tensorflow as tf

from differential_privacy.dp_optimizer import utils


ClipOption = collections.namedtuple("ClipOption",
                                    ["l2norm_bound", "clip"])


class AmortizedGaussianSanitizer(object):
  """Sanitizer with Gaussian noise and amortized privacy spending accounting.

  This sanitizes a tensor by first clipping the tensor, summing the tensor
  and then adding an appropriate amount of noise. It also uses an amortized
  accountant to keep track of privacy spending.
  """

  def __init__(self, accountant, default_option):
    """Construct an AmortizedGaussianSanitizer.

    Args:
      accountant: the privacy accountant. Expect an amortized one.
      default_option: the default ClipOption.
    """

    self._accountant = accountant
    self._default_option = default_option
    self._options = {}

  def set_option(self, tensor_name, option):
    """Set options for an individual tensor.

    Args:
      tensor_name: the name of the tensor.
      option: clip option.
    """

    self._options[tensor_name] = option

  def sanitize(self, x, eps_delta, sigma=None,
               option=ClipOption(None, None), tensor_name=None,
               num_examples=None, add_noise=True):
    """Sanitize the given tensor.

    This sanitizes a given tensor by first applying l2 norm clipping and then
    adding Gaussian noise. It calls the privacy accountant for updating the
    privacy spending.

    Args:
      x: the tensor to sanitize.
      eps_delta: a pair of eps, delta for (eps,delta)-DP. Use it to
        compute sigma if sigma is None.
      sigma: if sigma is not None, use sigma.
      option: a ClipOption which, if supplied, is used for
        clipping and adding noise.
      tensor_name: the name of the tensor.
      num_examples: if None, use the number of "rows" of x.
      add_noise: if True, then add noise, else just clip.
    Returns:
      the sanitized tensor. When add_noise is True, the operation that
      accumulates privacy spending is attached to it as a control dependency.
    """

    if sigma is None:
      # pylint: disable=unpacking-non-sequence
      eps, delta = eps_delta
      with tf.control_dependencies(
          [tf.Assert(tf.greater(eps, 0),
                     ["eps needs to be greater than 0"]),
           tf.Assert(tf.greater(delta, 0),
                     ["delta needs to be greater than 0"])]):
        # The following formula is taken from
        #   Dwork and Roth, The Algorithmic Foundations of Differential
        #   Privacy, Appendix A.
        #   http://www.cis.upenn.edu/~aaroth/Papers/privacybook.pdf
        sigma = tf.sqrt(2.0 * tf.log(1.25 / delta)) / eps

    l2norm_bound, clip = option
    if l2norm_bound is None:
      l2norm_bound, clip = self._default_option
      if ((tensor_name is not None) and
          (tensor_name in self._options)):
        l2norm_bound, clip = self._options[tensor_name]
    if clip:
      x = utils.BatchClipByL2norm(x, l2norm_bound)

    if add_noise:
      if num_examples is None:
        num_examples = tf.slice(tf.shape(x), [0], [1])
      privacy_accum_op = self._accountant.accumulate_privacy_spending(
          eps_delta, sigma, num_examples)
      with tf.control_dependencies([privacy_accum_op]):
        saned_x = utils.AddGaussianNoise(tf.reduce_sum(x, 0),
                                         sigma * l2norm_bound)
    else:
      saned_x = tf.reduce_sum(x, 0)
    return saned_x
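
To get a feel for the noise scale that `sanitize` ends up using, the calibration formula and the sum-then-noise step can be written in plain Python. This is only a sketch: `gaussian_sigma` and `noisy_sum` are hypothetical names, the rows are assumed to be clipped already, and `random.gauss` stands in for the TensorFlow noise op.

```python
import math
import random


def gaussian_sigma(eps, delta):
  """Noise multiplier sqrt(2 * log(1.25 / delta)) / eps, used if sigma is None."""
  assert eps > 0 and delta > 0
  return math.sqrt(2.0 * math.log(1.25 / delta)) / eps


def noisy_sum(rows, sigma, l2norm_bound, rng=random):
  """Sum rows (assumed already clipped) and add N(0, (sigma * bound)^2) noise."""
  stddev = sigma * l2norm_bound
  totals = [sum(column) for column in zip(*rows)]
  return [total + rng.gauss(0.0, stddev) for total in totals]


sigma = gaussian_sigma(eps=1.0, delta=1e-5)
print(sigma)
print(noisy_sum([[0.6, 0.8], [0.3, 0.4]], sigma, l2norm_bound=1.0))
```

The noise standard deviation is sigma times the clipping bound, so loosening the bound scales the noise by the same factor.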
differential_privacy/dp_optimizer/utils.py
0 → 100644
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for building and training NN models.
"""
from __future__ import division

import math

import numpy
import tensorflow as tf


class LayerParameters(object):
  """class that defines a non-conv layer."""

  def __init__(self):
    self.name = ""
    self.num_units = 0
    # BuildNetwork reads layer_parameters.with_bias, so the default must be
    # stored under that name (not _with_bias).
    self.with_bias = False
    self.relu = False
    self.gradient_l2norm_bound = 0.0
    self.bias_gradient_l2norm_bound = 0.0
    self.trainable = True
    self.weight_decay = 0.0


class ConvParameters(object):
  """class that defines a conv layer."""

  def __init__(self):
    self.patch_size = 5
    self.stride = 1
    self.in_channels = 1
    self.out_channels = 0
    self.with_bias = True
    self.relu = True
    self.max_pool = True
    self.max_pool_size = 2
    self.max_pool_stride = 2
    self.trainable = False
    self.in_size = 28
    self.name = ""
    self.num_outputs = 0
    self.bias_stddev = 0.1


# Parameters for a layered neural network.
class NetworkParameters(object):
  """class that defines the overall model structure."""

  def __init__(self):
    self.input_size = 0
    self.projection_type = 'NONE'  # NONE, RANDOM, PCA
    self.projection_dimensions = 0
    self.default_gradient_l2norm_bound = 0.0
    self.layer_parameters = []  # List of LayerParameters
    self.conv_parameters = []  # List of ConvParameters


def GetTensorOpName(x):
  """Get the name of the op that created a tensor.

  Useful for naming related tensors, as ':' in name field of op is not permitted

  Args:
    x: the input tensor.
  Returns:
    the name of the op.
  """

  t = x.name.rsplit(":", 1)
  if len(t) == 1:
    return x.name
  else:
    return t[0]


def BuildNetwork(inputs, network_parameters):
  """Build a network using the given parameters.

  Args:
    inputs: a Tensor of floats containing the input data.
    network_parameters: NetworkParameters object
      that describes the parameters for the network.
  Returns:
    outputs, projection, training_parameters: where outputs (a tensor) is
    the output of the network, projection is the projection variable (None
    if no projection is applied), and training_parameters (a dictionary that
    maps the name of each variable to a dictionary of parameters) is the
    parameters used during training.
  """

  training_parameters = {}
  num_inputs = network_parameters.input_size
  outputs = inputs
  projection = None

  # First apply convolutions, if needed
  for conv_param in network_parameters.conv_parameters:
    outputs = tf.reshape(
        outputs,
        [-1, conv_param.in_size, conv_param.in_size,
         conv_param.in_channels])
    conv_weights_name = "%s_conv_weight" % (conv_param.name)
    conv_bias_name = "%s_conv_bias" % (conv_param.name)
    conv_std_dev = 1.0 / (conv_param.patch_size
                          * math.sqrt(conv_param.in_channels))
    conv_weights = tf.Variable(
        tf.truncated_normal([conv_param.patch_size,
                             conv_param.patch_size,
                             conv_param.in_channels,
                             conv_param.out_channels],
                            stddev=conv_std_dev),
        trainable=conv_param.trainable,
        name=conv_weights_name)
    conv_bias = tf.Variable(
        tf.truncated_normal([conv_param.out_channels],
                            stddev=conv_param.bias_stddev),
        trainable=conv_param.trainable,
        name=conv_bias_name)
    training_parameters[conv_weights_name] = {}
    training_parameters[conv_bias_name] = {}
    conv = tf.nn.conv2d(outputs, conv_weights,
                        strides=[1, conv_param.stride,
                                 conv_param.stride, 1],
                        padding="SAME")
    relud = tf.nn.relu(conv + conv_bias)
    mpd = tf.nn.max_pool(relud, ksize=[1,
                                       conv_param.max_pool_size,
                                       conv_param.max_pool_size, 1],
                         strides=[1, conv_param.max_pool_stride,
                                  conv_param.max_pool_stride, 1],
                         padding="SAME")
    outputs = mpd
    num_inputs = conv_param.num_outputs
    # this should equal
    # in_size * in_size * out_channels / (stride * max_pool_stride)^2
    # (both spatial dimensions shrink; assumes in_size is divisible by the
    # strides).

  # once all the convs are done, reshape to make it flat
  outputs = tf.reshape(outputs, [-1, num_inputs])

  # Now project, if needed
  # Compare by value: `is not` tests object identity, which is unreliable
  # for strings.
  if network_parameters.projection_type != "NONE":
    projection = tf.Variable(tf.truncated_normal(
        [num_inputs, network_parameters.projection_dimensions],
        stddev=1.0 / math.sqrt(num_inputs)), trainable=False, name="projection")
    num_inputs = network_parameters.projection_dimensions
    outputs = tf.matmul(outputs, projection)

  # Now apply any other layers

  for layer_parameters in network_parameters.layer_parameters:
    num_units = layer_parameters.num_units
    hidden_weights_name = "%s_weight" % (layer_parameters.name)
    hidden_weights = tf.Variable(
        tf.truncated_normal([num_inputs, num_units],
                            stddev=1.0 / math.sqrt(num_inputs)),
        name=hidden_weights_name, trainable=layer_parameters.trainable)
    training_parameters[hidden_weights_name] = {}
    if layer_parameters.gradient_l2norm_bound:
      training_parameters[hidden_weights_name]["gradient_l2norm_bound"] = (
          layer_parameters.gradient_l2norm_bound)
    if layer_parameters.weight_decay:
      training_parameters[hidden_weights_name]["weight_decay"] = (
          layer_parameters.weight_decay)

    outputs = tf.matmul(outputs, hidden_weights)
    if layer_parameters.with_bias:
      hidden_biases_name = "%s_bias" % (layer_parameters.name)
      hidden_biases = tf.Variable(tf.zeros([num_units]),
                                  name=hidden_biases_name)
      training_parameters[hidden_biases_name] = {}
      if layer_parameters.bias_gradient_l2norm_bound:
        training_parameters[hidden_biases_name][
            "bias_gradient_l2norm_bound"] = (
                layer_parameters.bias_gradient_l2norm_bound)

      outputs += hidden_biases
    if layer_parameters.relu:
      outputs = tf.nn.relu(outputs)
    # num_inputs for the next layer is num_units in the current layer.
    num_inputs = num_units

  return outputs, projection, training_parameters


def VaryRate(start, end, saturate_epochs, epoch):
  """Compute a linearly varying number.

  Decrease linearly from start to end until epoch saturate_epochs.

  Args:
    start: the initial number.
    end: the end number.
    saturate_epochs: after this we do not reduce the number; if less than
      or equal to zero, just return start.
    epoch: the current learning epoch.
  Returns:
    the calculated number.
  """
  if saturate_epochs <= 0:
    return start

  step = (start - end) / (saturate_epochs - 1)
  if epoch < saturate_epochs:
    return start - step * epoch
  else:
    return end
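
A short worked example makes the schedule concrete. The sketch below re-states the same arithmetic in a standalone function (the name `vary_rate` and the sample values are assumptions for illustration) and evaluates it over eight epochs. Like the function above, it divides by `saturate_epochs - 1`, so `saturate_epochs = 1` is not handled.

```python
def vary_rate(start, end, saturate_epochs, epoch):
  """Linear interpolation from start to end, flat once saturated."""
  if saturate_epochs <= 0:
    return start
  step = (start - end) / (saturate_epochs - 1)
  return start - step * epoch if epoch < saturate_epochs else end


# Learning rate decaying from 0.1 to 0.05 over the first 6 epochs.
schedule = [vary_rate(0.1, 0.05, 6, epoch) for epoch in range(8)]
print(schedule)
```

The end value is already reached at epoch `saturate_epochs - 1`, and every later epoch returns `end` unchanged.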


def BatchClipByL2norm(t, upper_bound, name=None):
  """Clip an array of tensors by L2 norm.

  Shrink each dimension-0 slice of tensor (for matrix it is each row) such
  that the l2 norm is at most upper_bound. Here we clip each row as it
  corresponds to each example in the batch.

  Args:
    t: the input tensor.
    upper_bound: the upperbound of the L2 norm.
    name: optional name.
  Returns:
    the clipped tensor.
  """

  assert upper_bound > 0
  with tf.op_scope([t, upper_bound], name, "batch_clip_by_l2norm") as name:
    saved_shape = tf.shape(t)
    batch_size = tf.slice(saved_shape, [0], [1])
    t2 = tf.reshape(t, tf.concat(0, [batch_size, [-1]]))
    upper_bound_inv = tf.fill(tf.slice(saved_shape, [0], [1]),
                              tf.constant(1.0/upper_bound))
    # Add a small number to avoid divide by 0
    l2norm_inv = tf.rsqrt(tf.reduce_sum(t2 * t2, [1]) + 0.000001)
    scale = tf.minimum(l2norm_inv, upper_bound_inv) * upper_bound
    clipped_t = tf.matmul(tf.diag(scale), t2)
    clipped_t = tf.reshape(clipped_t, saved_shape, name=name)
  return clipped_t
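
The per-row scaling is easier to see on plain lists. The following sketch mirrors the same rule, `scale = min(1/norm, 1/bound) * bound`, without TensorFlow; `clip_rows_by_l2norm` is a hypothetical name and the input rows are made-up values.

```python
import math


def clip_rows_by_l2norm(rows, upper_bound):
  """Scale each row so that its L2 norm is at most upper_bound."""
  assert upper_bound > 0
  clipped = []
  for row in rows:
    # The small constant avoids dividing by zero for an all-zero row.
    norm_inv = 1.0 / math.sqrt(sum(v * v for v in row) + 1e-6)
    scale = min(norm_inv, 1.0 / upper_bound) * upper_bound
    clipped.append([v * scale for v in row])
  return clipped


clipped = clip_rows_by_l2norm([[3.0, 4.0], [0.3, 0.4]], upper_bound=1.0)
print(clipped)
```

A row whose norm exceeds the bound is shrunk onto the bound, while a row already inside the bound gets a scale of exactly 1 and is left untouched.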


def SoftThreshold(t, threshold_ratio, name=None):
  """Soft-threshold a tensor by the mean value.

  Soft-threshold each dimension-0 vector (for matrix it is each column) by
  the mean of absolute value multiplied by the threshold_ratio factor. Here
  we soft threshold each column as it corresponds to each unit in a layer.

  Args:
    t: the input tensor.
    threshold_ratio: the threshold ratio.
    name: the optional name for the returned tensor.
  Returns:
    the thresholded tensor, where each entry is soft-thresholded by
    threshold_ratio times the mean of the absolute value of each column.
  """

  assert threshold_ratio >= 0
  with tf.op_scope([t, threshold_ratio], name, "soft_thresholding") as name:
    saved_shape = tf.shape(t)
    t2 = tf.reshape(t, tf.concat(0, [tf.slice(saved_shape, [0], [1]), -1]))
    t_abs = tf.abs(t2)
    t_x = tf.sign(t2) * tf.nn.relu(t_abs -
                                   (tf.reduce_mean(t_abs, [0],
                                                   keep_dims=True) *
                                    threshold_ratio))
    return tf.reshape(t_x, saved_shape, name=name)
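
To make the column-wise soft-thresholding rule concrete, here is a small pure-Python sketch on a 2x2 matrix. The helper name `soft_threshold_columns` and the sample weights are assumptions for illustration; the rule itself, sign(v) * max(|v| - ratio * mean|column|, 0), follows the docstring above.

```python
def soft_threshold_columns(matrix, threshold_ratio):
    """Soft-threshold each column by ratio * mean absolute column value."""
    assert threshold_ratio >= 0
    num_rows = len(matrix)
    num_cols = len(matrix[0])
    # One threshold per column.
    thresholds = [
        threshold_ratio * sum(abs(row[j]) for row in matrix) / num_rows
        for j in range(num_cols)
    ]

    def shrink(v, thr):
        # Shrink the magnitude towards zero, keeping the sign.
        magnitude = max(abs(v) - thr, 0.0)
        return magnitude if v >= 0 else -magnitude

    return [[shrink(row[j], thresholds[j]) for j in range(num_cols)]
            for row in matrix]


weights = [[4.0, -1.0], [-2.0, 3.0]]
result = soft_threshold_columns(weights, 0.5)
print(result)
```

Here the column thresholds are 1.5 and 1.0, so the entry -1.0 is shrunk all the way to zero while the larger entries only lose magnitude.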


def AddGaussianNoise(t, sigma, name=None):
  """Add i.i.d. Gaussian noise (0, sigma^2) to every entry of t.

  Args:
    t: the input tensor.
    sigma: the stddev of the Gaussian noise.
    name: optional name.
  Returns:
    the noisy tensor.
  """

  with tf.op_scope([t, sigma], name, "add_gaussian_noise") as name:
    noisy_t = t + tf.random_normal(tf.shape(t), stddev=sigma)
  return noisy_t


def GenerateBinomialTable(m):
  """Generate binomial table.

  Args:
    m: the size of the table.
  Returns:
    A two dimensional array T where T[i][j] = (i choose j),
    for 0<= i, j <=m.
  """

  table = numpy.zeros((m + 1, m + 1), dtype=numpy.float64)
  for i in range(m + 1):
    table[i, 0] = 1
  for i in range(1, m + 1):
    for j in range(1, m + 1):
      v = table[i - 1, j] + table[i - 1, j - 1]
      assert not math.isnan(v) and not math.isinf(v)
      table[i, j] = v
  return tf.convert_to_tensor(table)
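
The table is filled with Pascal's rule, (i choose j) = (i-1 choose j) + (i-1 choose j-1). A minimal pure-Python sketch of the same recurrence follows; `binomial_table` is a hypothetical name, and it uses plain integer lists rather than a float64 array or a tensor.

```python
def binomial_table(m):
    """Return T with T[i][j] = (i choose j) for 0 <= i, j <= m."""
    table = [[0] * (m + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        table[i][0] = 1  # (i choose 0) is always 1.
    for i in range(1, m + 1):
        for j in range(1, m + 1):
            # Pascal's rule; entries with j > i stay 0.
            table[i][j] = table[i - 1][j] + table[i - 1][j - 1]
    return table


table = binomial_table(4)
for row in table:
    print(row)
```

Entries above the diagonal (j > i) remain zero, which matches the convention that (i choose j) = 0 when j exceeds i.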
differential_privacy/per_example_gradients/BUILD
0 → 100644
View file @
b4cb2454
package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package_group(
    name = "internal",
    packages = [
        "//differential_privacy/...",
    ],
)

py_library(
    name = "per_example_gradients",
    srcs = [
        "per_example_gradients.py",
    ],
    deps = [
    ],
)
differential_privacy/per_example_gradients/__init__.py
0 → 100644
View file @
b4cb2454
differential_privacy/per_example_gradients/per_example_gradients.py
0 → 100644
View file @
b4cb2454
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Per-example gradients for selected ops."""

import collections

import tensorflow as tf

OrderedDict = collections.OrderedDict


def _ListUnion(list_1, list_2):
  """Returns the union of two lists.

  Python sets can have a non-deterministic iteration order. In some
  contexts, this could lead to TensorFlow producing two different
  programs when the same Python script is run twice. In these contexts
  we use lists instead of sets.

  This function is not designed to be especially fast and should only
  be used with small lists.

  Args:
    list_1: A list
    list_2: Another list

  Returns:
    A new list containing one copy of each unique element of list_1 and
    list_2. Uniqueness is determined by "x in union" logic; e.g. two
    equal strings result in only one string of that value appearing in
    the union.

  Raises:
    TypeError: The arguments are not lists.
  """

  if not (isinstance(list_1, list) and isinstance(list_2, list)):
    raise TypeError("Arguments must be lists.")

  union = []
  for x in list_1 + list_2:
    if x not in union:
      union.append(x)

  return union
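
A short usage sketch of the order-preserving union described above. The standalone copy `list_union` and the string inputs are hypothetical; in this file the helper is applied to lists of tensors and ops.

```python
def list_union(list_1, list_2):
    """Order-preserving union of two lists, using "x in union" equality."""
    if not (isinstance(list_1, list) and isinstance(list_2, list)):
        raise TypeError("Arguments must be lists.")
    union = []
    for x in list_1 + list_2:
        if x not in union:
            union.append(x)
    return union


merged = list_union(["w1", "b1", "w1"], ["b1", "w2"])
print(merged)
```

Unlike a `set`, the result keeps first-seen order, so repeated runs of the same script visit the elements in the same order.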


def Interface(ys, xs):
  """
  Returns a dict mapping each element of xs to any of its consumers that are
  indirectly consumed by ys.

  Args:
    ys: The outputs
    xs: The inputs
  Returns:
    out: Dict mapping each member x of `xs` to a list of all Tensors that are
         direct consumers of x and are eventually consumed by a member of
         `ys`.
  """

  if isinstance(ys, (list, tuple)):
    queue = list(ys)
  else:
    queue = [ys]

  out = OrderedDict()
  if isinstance(xs, (list, tuple)):
    for x in xs:
      out[x] = []
  else:
    out[xs] = []

  done = set()

  while queue:
    y = queue.pop()
    if y in done:
      continue
    done = done.union(set([y]))
    for x in y.op.inputs:
      if x in out:
        out[x].append(y)
      else:
        assert id(x) not in [id(foo) for foo in out]
    queue.extend(y.op.inputs)

  return out


class PXGRegistry(object):
  """Per-Example Gradient registry.

  Maps names of ops to per-example gradient rules for those ops.
  These rules are only needed for ops that directly touch values that
  are shared between examples. For most machine learning applications,
  this means only ops that directly operate on the parameters.

  See http://arxiv.org/abs/1510.01799 for more information, and please
  consider citing that tech report if you use this function in published
  research.
  """

  def __init__(self):
    self.d = OrderedDict()

  def __call__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    if op.node_def.op not in self.d:
      raise NotImplementedError("No per-example gradient rule registered "
                                "for " + op.node_def.op + " in pxg_registry.")
    return self.d[op.node_def.op](op,
                                  colocate_gradients_with_ops,
                                  gate_gradients)

  def Register(self, op_name, pxg_class):
    """Associates `op_name` key with `pxg_class` value.

    Registers `pxg_class` as the class that will be called to perform
    per-example differentiation through ops with `op_name`.

    Args:
      op_name: String op name.
      pxg_class: A class with the same interface as MatMulPXG.
    """
    self.d[op_name] = pxg_class


pxg_registry = PXGRegistry()


class MatMulPXG(object):
  """Per-example gradient rule for MatMul op.
  """

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    """Construct an instance of the rule for `op`.

    Args:
      op: The Operation to differentiate through.
      colocate_gradients_with_ops: currently unsupported
      gate_gradients: currently unsupported
    """
    assert op.node_def.op == "MatMul"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def __call__(self, x, z_grads):
    """Build the graph for the per-example gradient through the op.

    Assumes that the MatMul was called with a design matrix with examples
    in rows as the first argument and parameters as the second argument.

    Args:
      x: The Tensor to differentiate with respect to. This tensor must
         represent the weights.
      z_grads: The list of gradients on the output of the op.

    Returns:
      x_grads: A Tensor containing the gradient with respect to `x` for
        each example. This is a 3-D tensor, with the first axis corresponding
        to examples and the remaining axes matching the shape of x.
    """
    idx = list(self.op.inputs).index(x)
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    assert idx == 1  # We expect weights to be arg 1
    # We don't expect anyone to per-example differentiate with respect
    # to anything other than the weights.
    x, w = self.op.inputs
    z_grads, = z_grads
    x_expanded = tf.expand_dims(x, 2)
    z_grads_expanded = tf.expand_dims(z_grads, 1)
    return tf.mul(x_expanded, z_grads_expanded)


pxg_registry.Register("MatMul", MatMulPXG)
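
The MatMul rule amounts to one outer product per example: for z = x W, the gradient with respect to W for example n is the outer product of row n of x with row n of the output gradient. The pure-Python sketch below illustrates this under that assumption; `per_example_matmul_grads` and the sample values are hypothetical and stand in for the broadcasted multiply of the two expanded tensors.

```python
def per_example_matmul_grads(x, z_grads):
    """x: [batch, n_in], z_grads: [batch, n_out] -> [batch, n_in, n_out]."""
    return [[[x_i * g_j for g_j in g_row] for x_i in x_row]
            for x_row, g_row in zip(x, z_grads)]


x = [[1.0, 2.0], [3.0, 4.0]]                   # two examples, two inputs
z_grads = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # two examples, three outputs
px = per_example_matmul_grads(x, z_grads)

# Summing over the example axis recovers the usual batch gradient x^T z_grads.
batch_grad = [[sum(px[n][i][j] for n in range(len(x)))
               for j in range(3)]
              for i in range(2)]
print(px)
print(batch_grad)
```

Keeping the example axis is what allows each example's gradient to be clipped separately before the gradients are summed.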


class Conv2DPXG(object):
  """Per-example gradient rule of Conv2d op.

  Same interface as MatMulPXG.
  """

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):

    assert op.node_def.op == "Conv2D"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def _PxConv2DBuilder(self, input_, w, strides, padding):
    """conv2d run separately per example, to help compute per-example gradients.

    Args:
      input_: tensor containing a minibatch of images / feature maps.
        Shape [batch_size, rows, columns, channels]
      w: convolution kernels. Shape
        [kernel rows, kernel columns, input channels, output channels]
      strides: passed through to regular conv_2d
      padding: passed through to regular conv_2d

    Returns:
      conv: the output of the convolution.
        single tensor, same as what regular conv_2d does
      w_px: a list of batch_size copies of w. each copy was used
        for the corresponding example in the minibatch.
        calling tf.gradients on the copy gives the gradient for just
        that example.
    """
    input_shape = [int(e) for e in input_.get_shape()]
    batch_size = input_shape[0]
    # range (rather than the Python 2-only xrange) works on Python 2 and 3.
    input_px = [tf.slice(
        input_, [example] + [0] * 3, [1] + input_shape[1:]) for example
                in range(batch_size)]
    for input_x in input_px:
      assert int(input_x.get_shape()[0]) == 1
    w_px = [tf.identity(w) for example in range(batch_size)]
    conv_px = [tf.nn.conv2d(input_x, w_x,
                            strides=strides,
                            padding=padding)
               for input_x, w_x in zip(input_px, w_px)]
    for conv_x in conv_px:
      num_x = int(conv_x.get_shape()[0])
      assert num_x == 1, num_x
    assert len(conv_px) == batch_size
    conv = tf.concat(0, conv_px)
    assert int(conv.get_shape()[0]) == batch_size
    return conv, w_px

  def __call__(self, w, z_grads):
    idx = list(self.op.inputs).index(w)
    # Make sure that `op` was actually applied to `w`
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    # The following assert may be removed when we are ready to use this
    # for general purpose code.
    # This assert is only expected to hold in the context of our preliminary
    # MNIST experiments.
    assert idx == 1  # We expect convolution weights to be arg 1

    images, filters = self.op.inputs
    strides = self.op.get_attr("strides")
    padding = self.op.get_attr("padding")
    # Currently assuming that one specifies at most these four arguments and
    # that all other arguments to conv2d are set to default.

    conv, w_px = self._PxConv2DBuilder(images, filters, strides, padding)
    z_grads, = z_grads

    gradients_list = tf.gradients(conv, w_px, z_grads,
                                  colocate_gradients_with_ops=
                                  self.colocate_gradients_with_ops,
                                  gate_gradients=self.gate_gradients)

    return tf.pack(gradients_list)


pxg_registry.Register("Conv2D", Conv2DPXG)


class AddPXG(object):
  """Per-example gradient rule for Add op.

  Same interface as MatMulPXG.
  """

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):

    assert op.node_def.op == "Add"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def __call__(self, x, z_grads):
    idx = list(self.op.inputs).index(x)
    # Make sure that `op` was actually applied to `x`
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    # The following assert may be removed when we are ready to use this
    # for general purpose code.
    # This assert is only expected to hold in the context of our preliminary
    # MNIST experiments.
    assert idx == 1  # We expect biases to be arg 1
    # We don't expect anyone to per-example differentiate with respect
    # to anything other than the biases.
    x, b = self.op.inputs
    z_grads, = z_grads
    return z_grads


pxg_registry.Register("Add", AddPXG)


def PerExampleGradients(ys, xs, grad_ys=None, name="gradients",
                        colocate_gradients_with_ops=False,
                        gate_gradients=False):
  """Symbolic differentiation, separately for each example.

  Matches the interface of tf.gradients, but the return values each have an
  additional axis corresponding to the examples.

  Assumes that the cost in `ys` is additive across examples.
  e.g., no batch normalization.
  Individual rules for each op specify their own assumptions about how
  examples are put into tensors.
  """

  # Find the interface between the xs and the cost
  for x in xs:
    assert isinstance(x, tf.Tensor), type(x)
  interface = Interface(ys, xs)
  merged_interface = []
  for x in xs:
    merged_interface = _ListUnion(merged_interface, interface[x])
  # Differentiate with respect to the interface
  interface_gradients = tf.gradients(ys, merged_interface, grad_ys=grad_ys,
                                     name=name,
                                     colocate_gradients_with_ops=
                                     colocate_gradients_with_ops,
                                     gate_gradients=gate_gradients)
  grad_dict = OrderedDict(zip(merged_interface, interface_gradients))
  # Build the per-example gradients with respect to the xs
  if colocate_gradients_with_ops:
    raise NotImplementedError("The per-example gradients are not yet "
                              "colocated with ops.")
  if gate_gradients:
    raise NotImplementedError("The per-example gradients are not yet "
                              "gated.")
  out = []
  for x in xs:
    zs = interface[x]
    ops = []
    for z in zs:
      ops = _ListUnion(ops, [z.op])
    if len(ops) != 1:
      raise NotImplementedError("Currently we only support the case "
                                "where each x is consumed by exactly "
                                "one op. but %s is consumed by %d ops."
                                % (x.name, len(ops)))
    op = ops[0]
    pxg_rule = pxg_registry(op, colocate_gradients_with_ops, gate_gradients)
    x_grad = pxg_rule(x, [grad_dict[z] for z in zs])
    out.append(x_grad)
  return out