ModelZoo / ResNet50_tensorflow — commit ed7c95fa
authored May 01, 2017 by David G. Andersen
parent 6e934f5a

Open source release of adversarial crypto code corresponding
to Abadi & Andersen paper.

2 changed files, +330 −0:
- adversarial_crypto/README.md (+56 −0)
- adversarial_crypto/train_eval.py (+274 −0)
adversarial_crypto/README.md
# Learning to Protect Communications with Adversarial Neural Cryptography
This is a slightly-updated model used for the paper
["Learning to Protect Communications with Adversarial Neural Cryptography"](https://arxiv.org/abs/1610.06918).
> We ask whether neural networks can learn to use secret keys to protect
> information from other neural networks. Specifically, we focus on ensuring
> confidentiality properties in a multiagent system, and we specify those
> properties in terms of an adversary. Thus, a system may consist of neural
> networks named Alice and Bob, and we aim to limit what a third neural
> network named Eve learns from eavesdropping on the communication between
> Alice and Bob. We do not prescribe specific cryptographic algorithms to
> these neural networks; instead, we train end-to-end, adversarially.
> We demonstrate that the neural networks can learn how to perform forms of
> encryption and decryption, and also how to apply these operations
> selectively in order to meet confidentiality goals.
This code allows you to train an encoder/decoder/adversary triplet
and evaluate their effectiveness on randomly generated input and key
pairs.
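For intuition, the randomly generated input and key pairs are batches of "boolean" values encoded as -1 or 1. A minimal NumPy sketch of that encoding, mirroring the `batch_of_random_bools` helper in `train_eval.py` (NumPy stands in here for the TensorFlow op; sizes follow the `TEXT_SIZE`/`KEY_SIZE` constants in the source):

```python
import numpy as np

def batch_of_random_bools(batch_size, n):
    # Each "boolean" is encoded as -1.0 or 1.0, matching the
    # representation used by the TensorFlow helper in train_eval.py.
    as_int = np.random.randint(0, 2, size=(batch_size, n))
    return (as_int * 2 - 1).astype(np.float32)

messages = batch_of_random_bools(4096, 16)  # plaintext batch (TEXT_SIZE = 16)
keys = batch_of_random_bools(4096, 16)      # key batch (KEY_SIZE = 16)
```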
## Prerequisites
The only software requirement for running the encoder and decoder is a
working TensorFlow installation, r0.12 or later.
## Training and evaluating
After installing TensorFlow and ensuring that your paths are configured
appropriately:
    python train_eval.py
This will begin training a fresh model. If and when the model becomes
sufficiently well-trained, it will reset the Eve model multiple times
and retrain it from scratch, outputting the accuracy thus obtained
in each run.
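The accuracy figures logged during training are bit-reconstruction errors: the networks' ±1 outputs are mapped back to {0, 1} and compared against the plaintext bits. A NumPy sketch of the per-example formula used for both Bob and Eve (the `eve_bits_wrong` / `bob_bits_wrong` expression in `train_eval.py`):

```python
import numpy as np

def bits_wrong(predicted, truth):
    # Map the -1/1 encoding back to 0/1 and count mismatched bits
    # per example, as train_eval.py does for Bob and Eve.
    return np.sum(np.abs((predicted + 1.0) / 2.0 - (truth + 1.0) / 2.0),
                  axis=1)

truth = np.array([[1.0, -1.0, 1.0, -1.0]])
guess = np.array([[1.0, 1.0, 1.0, -1.0]])  # one bit flipped
print(bits_wrong(guess, truth))  # -> [1.]
```

Dividing the accumulated sum by the batch size, as `doeval` does, turns this into the average number of bits wrong per message that the script prints.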
## Model differences from the paper
The model has been simplified slightly from the one described in
the paper: the convolutional layer width was reduced by a factor
of two. In the version in the paper, there was a nonlinear unit
after the fully-connected layer; that nonlinearity has been removed
here. These changes improve the robustness of training. The
initializer for the convolution layers has been switched to the
tf.contrib.layers default of xavier_initializer instead of
a simpler truncated_normal.
## Contact information
This model repository is maintained by David G. Andersen
([dave-andersen](https://github.com/dave-andersen)).
adversarial_crypto/train_eval.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adversarial training to learn trivial encryption functions,
from the paper "Learning to Protect Communications with
Adversarial Neural Cryptography", Abadi & Andersen, 2016.

https://arxiv.org/abs/1610.06918

This program creates and trains three neural networks,
termed Alice, Bob, and Eve.  Alice takes inputs
in_m (message), in_k (key) and outputs 'ciphertext'.

Bob takes inputs in_k, ciphertext and tries to reconstruct
the message.

Eve is an adversarial network that takes input ciphertext
and also tries to reconstruct the message.

The main function attempts to train these networks and then
evaluates them, all on random plaintext and key values.
"""

# TensorFlow Python 3 compatibility
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import signal
import sys

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

flags = tf.app.flags

flags.DEFINE_float('learning_rate', 0.0008, 'Constant learning rate')
flags.DEFINE_integer('batch_size', 4096, 'Batch size')

FLAGS = flags.FLAGS

# Input and output configuration.
TEXT_SIZE = 16
KEY_SIZE = 16

# Training parameters.
ITERS_PER_ACTOR = 1
EVE_MULTIPLIER = 2  # Train Eve 2x for every step of Alice/Bob.

# Train until either max loops or Alice/Bob "good enough":
MAX_TRAINING_LOOPS = 850000
BOB_LOSS_THRESH = 0.02  # Exit when Bob loss < 0.02 and Eve > 7.7 bits.
EVE_LOSS_THRESH = 7.7

# Logging and evaluation.
PRINT_EVERY = 200  # In training, log every 200 steps.
EVE_EXTRA_ROUNDS = 2000  # At the end, train Eve a bit more.
RETRAIN_EVE_ITERS = 10000  # Retrain Eve up to ITERS * LOOPS times.
RETRAIN_EVE_LOOPS = 25  # With an evaluation each loop.
NUMBER_OF_EVE_RESETS = 5  # And do this up to 5 times with a fresh Eve.
# Use EVAL_BATCHES samples each time we check accuracy.
EVAL_BATCHES = 1


def batch_of_random_bools(batch_size, n):
  """Return a batch of random "boolean" numbers.

  Args:
    batch_size:  Batch size dimension of returned tensor.
    n:  number of entries per batch.

  Returns:
    A [batch_size, n] tensor of "boolean" numbers, where each number is
    represented as -1 or 1.
  """
  as_int = tf.random_uniform(
      [batch_size, n], minval=0, maxval=2, dtype=tf.int32)
  expanded_range = (as_int * 2) - 1
  return tf.cast(expanded_range, tf.float32)


class AdversarialCrypto(object):
  """Primary model implementation class for Adversarial Neural Crypto.

  This class contains the code for the model itself,
  and when created, plumbs the pathways from Alice to Bob and
  Eve, creates the optimizers and loss functions, etc.

  Attributes:
    eve_loss:  Eve's loss function.
    bob_loss:  Bob's loss function.  Different units from eve_loss.
    eve_optimizer:  A tf op that runs Eve's optimizer.
    bob_optimizer:  A tf op that runs Bob's optimizer.
    bob_reconstruction_loss:  Bob's message reconstruction loss,
      which is comparable to eve_loss.
    reset_eve_vars:  Execute this op to completely reset Eve.
  """

  def get_message_and_key(self):
    """Generate random pseudo-boolean key and message values."""

    batch_size = tf.placeholder_with_default(FLAGS.batch_size, shape=[])

    in_m = batch_of_random_bools(batch_size, TEXT_SIZE)
    in_k = batch_of_random_bools(batch_size, KEY_SIZE)
    return in_m, in_k

  def model(self, collection, message, key=None):
    """The model for Alice, Bob, and Eve.  If key is None, the first
    fully-connected layer takes only the message as input.  Otherwise,
    it uses both the key and the message.

    Args:
      collection:  The graph keys collection to add new vars to.
      message:  The input message to process.
      key:  The input key (if any) to use.
    """

    if key is not None:
      combined_message = tf.concat(1, [message, key])
    else:
      combined_message = message

    # Ensure that all variables created are in the specified collection.
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.fully_connected, tf.contrib.layers.convolution],
        variables_collections=[collection]):

      fc = tf.contrib.layers.fully_connected(
          combined_message,
          TEXT_SIZE + KEY_SIZE,
          biases_initializer=tf.constant_initializer(0.0),
          activation_fn=None)

      # Perform a sequence of 1D convolutions (by expanding the message
      # out to 2D and then squeezing it back down).
      fc = tf.expand_dims(fc, 2)
      # 2,1 -> 1,2
      conv = tf.contrib.layers.convolution(
          fc, 2, 2, 2, 'SAME', activation_fn=tf.nn.sigmoid)
      # 1,2 -> 1,2
      conv = tf.contrib.layers.convolution(
          conv, 2, 1, 1, 'SAME', activation_fn=tf.nn.sigmoid)
      # 1,2 -> 1,1
      conv = tf.contrib.layers.convolution(
          conv, 1, 1, 1, 'SAME', activation_fn=tf.nn.tanh)
      conv = tf.squeeze(conv, 2)
      return conv

  def __init__(self):
    in_m, in_k = self.get_message_and_key()
    encrypted = self.model('alice', in_m, in_k)
    decrypted = self.model('bob', encrypted, in_k)
    eve_out = self.model('eve', encrypted, None)

    self.reset_eve_vars = tf.group(
        *[w.initializer for w in tf.get_collection('eve')])

    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

    # Eve's goal is to decrypt the entire message:
    eve_bits_wrong = tf.reduce_sum(
        tf.abs((eve_out + 1.0) / 2.0 - (in_m + 1.0) / 2.0), [1])
    self.eve_loss = tf.reduce_sum(eve_bits_wrong)
    self.eve_optimizer = optimizer.minimize(
        self.eve_loss, var_list=tf.get_collection('eve'))

    # Alice and Bob want to be accurate...
    self.bob_bits_wrong = tf.reduce_sum(
        tf.abs((decrypted + 1.0) / 2.0 - (in_m + 1.0) / 2.0), [1])
    # ... and to not let Eve do better than guessing.
    self.bob_reconstruction_loss = tf.reduce_sum(self.bob_bits_wrong)
    bob_eve_error_deviation = tf.abs(float(TEXT_SIZE) / 2.0 - eve_bits_wrong)
    # 7-9 bits wrong is OK too, so we squish the error function a bit.
    # Without doing this, we often tend to hang out at 0.25 / 7.5 error,
    # and it seems bad to have continued, high communication error.
    bob_eve_loss = tf.reduce_sum(
        tf.square(bob_eve_error_deviation) / (TEXT_SIZE / 2)**2)

    # Rescale the losses to [0, 1] per example and combine.
    self.bob_loss = (self.bob_reconstruction_loss / TEXT_SIZE + bob_eve_loss)

    self.bob_optimizer = optimizer.minimize(
        self.bob_loss,
        var_list=(tf.get_collection('alice') + tf.get_collection('bob')))


def doeval(s, ac, n, itercount):
  """Evaluate the current network on n batches of random examples.

  Args:
    s:  The current TensorFlow session.
    ac:  an instance of the AdversarialCrypto class.
    n:  The number of iterations to run.
    itercount:  Iteration count label for logging.

  Returns:
    Bob's and Eve's loss, as a percent of bits incorrect.
  """

  bob_loss_accum = 0
  eve_loss_accum = 0
  for _ in xrange(n):
    bl, el = s.run([ac.bob_reconstruction_loss, ac.eve_loss])
    bob_loss_accum += bl
    eve_loss_accum += el
  bob_loss_percent = bob_loss_accum / (n * FLAGS.batch_size)
  eve_loss_percent = eve_loss_accum / (n * FLAGS.batch_size)
  print('%d %.2f %.2f' % (itercount, bob_loss_percent, eve_loss_percent))
  sys.stdout.flush()
  return bob_loss_percent, eve_loss_percent


def train_until_thresh(s, ac):
  for j in xrange(MAX_TRAINING_LOOPS):
    for _ in xrange(ITERS_PER_ACTOR):
      s.run(ac.bob_optimizer)
    for _ in xrange(ITERS_PER_ACTOR * EVE_MULTIPLIER):
      s.run(ac.eve_optimizer)
    if j % PRINT_EVERY == 0:
      bob_avg_loss, eve_avg_loss = doeval(s, ac, EVAL_BATCHES, j)
      if (bob_avg_loss < BOB_LOSS_THRESH and eve_avg_loss > EVE_LOSS_THRESH):
        print('Target losses achieved.')
        return True
  return False


def train_and_evaluate():
  """Run the full training and evaluation loop."""

  ac = AdversarialCrypto()
  init = tf.global_variables_initializer()

  with tf.Session() as s:
    s.run(init)
    print('# Batch size: ', FLAGS.batch_size)
    print('# Iter Bob_Recon_Error Eve_Recon_Error')

    if train_until_thresh(s, ac):
      for _ in xrange(EVE_EXTRA_ROUNDS):
        s.run(ac.eve_optimizer)
      print('Loss after eve extra training:')
      doeval(s, ac, EVAL_BATCHES * 2, 0)
      for _ in xrange(NUMBER_OF_EVE_RESETS):
        print('Resetting Eve')
        s.run(ac.reset_eve_vars)
        eve_counter = 0
        for _ in xrange(RETRAIN_EVE_LOOPS):
          for _ in xrange(RETRAIN_EVE_ITERS):
            eve_counter += 1
            s.run(ac.eve_optimizer)
          doeval(s, ac, EVAL_BATCHES, eve_counter)
        doeval(s, ac, EVAL_BATCHES, eve_counter)


def main(unused_argv):
  # Exit more quietly with Ctrl-C.
  signal.signal(signal.SIGINT, signal.SIG_DFL)
  train_and_evaluate()


if __name__ == '__main__':
  tf.app.run()