Commit 9ee773e3 authored by Ilya Mironov's avatar Ilya Mironov
Browse files

Adding several plots for the slide deck.

parent 3b158095
......@@ -51,9 +51,11 @@ The output is written to the console.
* plot_partition.py — Script for producing partition.pdf, a detailed breakdown of privacy
costs for Confident-GNMax with smooth sensitivity analysis (takes ~50 hours).
* rdp_flow.py and plot_ls_q.py are currently not used.
* plots_for_slides.py — Script for producing several plots for the slide deck.
* download.py — Utility script for populating the data/ directory.
* plot_ls_q.py is not used.
All Python files take flags. Run script_name.py --help for help on flags.
......@@ -13,9 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Plots two graphs illustrating cost of privacy per answered query.
PRESENTLY NOT USED.
"""Plots graphs for the slide deck.
A script in support of the PATE2 paper. The input is a file containing a numpy
array of votes, one query per row, one class per column. Ex:
......@@ -23,7 +21,7 @@ array of votes, one query per row, one class per column. Ex:
31, 16, ..., 0
...
0, 86, ..., 438
The output is written to a specified directory and consists of two pdf files.
The output graphs are visualized using the TkAgg backend.
"""
from __future__ import absolute_import
from __future__ import division
......@@ -35,10 +33,10 @@ import sys
sys.path.append('..') # Main modules reside in the parent directory.
from absl import app
from absl import flags
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
import numpy as np
......@@ -48,21 +46,30 @@ import random
plt.style.use('ggplot')
FLAGS = flags.FLAGS
flags.DEFINE_string('counts_file', '', 'Counts file.')
flags.DEFINE_string('counts_file', None, 'Counts file.')
flags.DEFINE_string('figures_dir', '', 'Path where figures are written to.')
flags.DEFINE_boolean('transparent', False, 'Set background to transparent.')
flags.mark_flag_as_required('counts_file')
def plot_rdp_curve_per_example(votes, sigmas):
orders = np.linspace(1., 100., endpoint=True, num=1000)
orders[0] = 1.001
def setup_plot():
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig.patch.set_alpha(0)
if FLAGS.transparent:
fig.patch.set_alpha(0)
return fig, ax
for i in xrange(votes.shape[0]):
def plot_rdp_curve_per_example(votes, sigmas):
orders = np.linspace(1., 100., endpoint=True, num=1000)
orders[0] = 1.001
fig, ax = setup_plot()
for i in range(votes.shape[0]):
for sigma in sigmas:
logq = pate.compute_logq_gaussian(votes[i,], sigma)
rdp = pate.rdp_gaussian(logq, sigma, orders)
......@@ -95,26 +102,19 @@ def plot_rdp_curve_per_example(votes, sigmas):
def plot_rdp_of_sigma(v, order):
sigmas = np.linspace(1., 1000., endpoint=True, num=1000)
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig, ax = setup_plot()
fig.patch.set_alpha(0)
y = np.zeros(len(sigmas))
for i, sigma in enumerate(sigmas):
logq = pate.compute_logq_gaussian(v, sigma)
y[i] = pate.rdp_gaussian(logq, sigma, order)
ax.plot(
sigmas,
y,
alpha=.8,
linewidth=5)
ax.plot(sigmas, y, alpha=.8, linewidth=5)
plt.xlim(xmin=1, xmax=1000)
plt.ylim(ymin=0)
#plt.yticks([0, .0004, .0008, .0012])
# plt.yticks([0, .0004, .0008, .0012])
ax.tick_params(labelleft='off')
plt.xlabel(r'Noise $\sigma$', fontsize=16)
plt.ylabel(r'RDP at order $\alpha={}$'.format(order), fontsize=16)
......@@ -125,7 +125,7 @@ def plot_rdp_of_sigma(v, order):
def compute_rdp_curve(votes, threshold, sigma1, sigma2, orders,
target_answered):
target_answered):
rdp_cum = np.zeros(len(orders))
answered = 0
for i, v in enumerate(votes):
......@@ -147,10 +147,7 @@ def plot_rdp_total(votes, sigmas):
orders = np.linspace(1., 100., endpoint=True, num=100)
orders[0] = 1.1
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig.patch.set_alpha(0)
fig, ax = setup_plot()
target_answered = 2000
......@@ -171,7 +168,6 @@ def plot_rdp_total(votes, sigmas):
# label=r'Data-independent bound, $\sigma$={}'.format(int(sigma)),
# linewidth=10)
plt.xlim(xmin=1, xmax=100)
plt.ylim(ymin=0)
plt.xticks([1, 20, 40, 60, 80, 100])
......@@ -185,41 +181,32 @@ def plot_rdp_total(votes, sigmas):
plt.show()
def plot_one_curve():
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig.patch.set_alpha(0)
def plot_data_ind_curve():
fig, ax = setup_plot()
orders = np.linspace(1., 10., endpoint=True, num=1000)
orders[0] = 1.01
ax.plot(
orders,
pate.rdp_data_independent_gaussian(1., orders),
alpha=.5,
color='gray',
linewidth=10)
orders,
pate.rdp_data_independent_gaussian(1., orders),
alpha=.5,
color='gray',
linewidth=10)
#plt.yticks([])
# plt.yticks([])
plt.xlim(xmin=1, xmax=10)
plt.ylim(ymin=0)
plt.xticks([1, 3, 5, 7, 9])
ax.tick_params(labelsize=14)
plt.show()
def plot_two_curves():
def plot_two_data_ind_curves():
orders = np.linspace(1., 100., endpoint=True, num=1000)
orders[0] = 1.001
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig.patch.set_alpha(0)
ax.plot([], [])
ax.plot([], [])
fig, ax = setup_plot()
for sigma in [100, 150]:
ax.plot(
......@@ -242,10 +229,7 @@ def plot_two_curves():
def scatter_plot(votes, threshold, sigma1, sigma2, order):
fig, ax = plt.subplots()
fig.set_figheight(4.5)
fig.set_figwidth(4.7)
fig.patch.set_alpha(0)
fig, ax = setup_plot()
x = []
y = []
for i, v in enumerate(votes):
......@@ -259,13 +243,14 @@ def scatter_plot(votes, threshold, sigma1, sigma2, order):
y.append(pate.rdp_gaussian(logq_step2, sigma2, order))
print('Selected {} queries.'.format(len(x)))
#data_ind = pate.rdp_data_independent_gaussian(sigma, order)
#plt.plot([0, 5000], [data_ind, data_ind], color='tab:blue', linestyle='-', linewidth=2)
# Plot the data-independent curve:
# data_ind = pate.rdp_data_independent_gaussian(sigma, order)
# plt.plot([0, 5000], [data_ind, data_ind], color='tab:blue', linestyle='-', linewidth=2)
ax.set_yscale('log')
plt.xlim(xmin=0, xmax=5000)
plt.ylim(ymin=1e-300, ymax=1)
plt.yticks([1, 1e-100, 1e-200, 1e-300])
plt.scatter(x, y, s = 1, alpha=0.5)
plt.scatter(x, y, s=1, alpha=0.5)
plt.ylabel(r'RDP at $\alpha={}$'.format(order), fontsize=16)
plt.xlabel(r'max count', fontsize=16)
ax.tick_params(labelsize=14)
......@@ -278,21 +263,20 @@ def main(argv):
print('Reading raw votes from ' + fin_name)
sys.stdout.flush()
#plot_one_curve()
#plot_two_curves()
votes = np.load(fin_name)
votes = votes[:12000,] # truncate to 4000 samples
plot_data_ind_curve()
plot_two_data_ind_curves()
v1 = [2550, 2200, 250] # based on votes[2,]
#v2 = [2600, 2200, 200] # based on votes[381,]
#plot_rdp_curve_per_example(np.array([v1]), (100., 150.))
# v2 = [2600, 2200, 200] # based on votes[381,]
plot_rdp_curve_per_example(np.array([v1]), (100., 150.))
#plot_rdp_of_sigma(np.array(v1), 20.)
plot_rdp_of_sigma(np.array(v1), 20.)
votes = np.load(fin_name)
#plot_rdp_total(votes, (100., 150.))
scatter_plot(votes[:6000, ], None, None, 100, 20)
scatter_plot(votes[:6000, ], 3500, 1500, 100, 20)
plot_rdp_total(votes[:12000, ], (100., 150.))
scatter_plot(votes[:6000, ], None, None, 100, 20) # w/o thresholding
scatter_plot(votes[:6000, ], 3500, 1500, 100, 20) # with thresholding
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment