Commit 30aeec75 authored by Toby Boyd, committed by GitHub

Merge pull request #2 from tensorflow/master

Sync to tensorflow-master
parents 68a18b70 78007443
adversarial_crypto/* @dave-andersen
adversarial_text/* @rsepassi
attention_ocr/* @alexgorban
autoencoders/* @snurkabill
cognitive_mapping_and_planning/* @s-gupta
compression/* @nmjohn
differential_privacy/* @panyx0718
domain_adaptation/* @bousmalis @ddohan
im2txt/* @cshallue
inception/* @shlens @vincentvanhoucke
learning_to_remember_rare_events/* @lukaszkaiser @ofirnachum
lfads/* @jazcollins @susillo
lm_1b/* @oriolvinyals @panyx0718
namignizer/* @knathanieltucker
neural_gpu/* @lukaszkaiser
neural_programmer/* @arvind2505
next_frame_prediction/* @panyx0718
object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
pcl_rl/* @ofirnachum
ptn/* @xcyan @arkanath @hellojas @honglaklee
real_nvp/* @laurent-dinh
rebar/* @gjtucker
resnet/* @panyx0718
skip_thoughts/* @cshallue
slim/* @sguada @nathansilberman
street/* @theraysmith
swivel/* @waterson
syntaxnet/* @calberti @andorardo
textsum/* @panyx0718 @peterjliu
transformer/* @daviddao
tutorials/embedding/* @zffchen78 @a-dai
tutorials/image/* @sherrym @shlens
tutorials/rnn/* @lukaszkaiser @ebrevdo
video_prediction/* @cbfinn
# Contributing guidelines
If you have created a model and would like to publish it here, please send us a
pull request. For those just getting started with pull requests, GitHub has a
[howto](https://help.github.com/articles/using-pull-requests/).
The code for any model in this repository is licensed under the Apache License
......
......@@ -22,12 +22,15 @@ running TensorFlow 0.12 or earlier, please
- [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a large-scale life-long memory module for use in deep learning.
- [lfads](lfads): sequential variational autoencoder for analyzing neuroscience data.
- [lm_1b](lm_1b): language modeling on the one billion word benchmark.
- [namignizer](namignizer): recognize and generate names.
- [neural_gpu](neural_gpu): highly parallel neural computer.
- [neural_programmer](neural_programmer): neural network augmented with logic and mathematic operations.
- [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
- [object_detection](object_detection): localizing and identifying multiple objects in a single image.
- [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
- [rebar](rebar): low-variance, unbiased gradient estimates for discrete latent variable models.
- [resnet](resnet): deep and wide residual networks.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
- [slim](slim): image classification models in TF-Slim.
......
......@@ -118,7 +118,7 @@ class AdversarialCrypto(object):
def model(self, collection, message, key=None):
"""The model for Alice, Bob, and Eve. If key=None, the first FC layer
takes only the message as input. Otherwise, it uses both the key
and the message.
Args:
......
licenses(["notice"]) # Apache 2.0
# Binaries
# ==============================================================================
py_binary(
......@@ -5,6 +7,8 @@ py_binary(
srcs = ["evaluate.py"],
deps = [
":graphs",
# google3 file dep,
# tensorflow dep,
],
)
......@@ -14,6 +18,8 @@ py_binary(
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
],
)
......@@ -25,6 +31,8 @@ py_binary(
deps = [
":graphs",
":train_utils",
# google3 file dep,
# tensorflow dep,
],
)
......@@ -37,18 +45,23 @@ py_library(
":adversarial_losses",
":inputs",
":layers",
# tensorflow dep,
],
)
py_library(
name = "adversarial_losses",
srcs = ["adversarial_losses.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "inputs",
srcs = ["inputs.py"],
deps = [
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
......@@ -56,11 +69,18 @@ py_library(
py_library(
name = "layers",
srcs = ["layers.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "train_utils",
srcs = ["train_utils.py"],
deps = [
# numpy dep,
# tensorflow dep,
],
)
# Tests
......@@ -71,6 +91,7 @@ py_test(
srcs = ["graphs_test.py"],
deps = [
":graphs",
# tensorflow dep,
"//adversarial_text/data:data_utils",
],
)
......@@ -56,7 +56,6 @@ $ bazel run :pretrain -- \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--num_candidate_samples=1024 \
--batch_size=256 \
--learning_rate=0.001 \
--learning_rate_decay_factor=0.9999 \
......@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \
--rnn_cell_size=1024 \
--cl_num_layers=1 \
--cl_hidden_size=30 \
--batch_size=64 \
--learning_rate=0.0005 \
--learning_rate_decay_factor=0.9998 \
......@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings \
--adv_training_method=vat \
--perturb_norm_length=5.0
```
### Evaluate on test data
......@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in
### Command-Line Flags
Flags related to distributed training and the training loop itself are defined
in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/train_utils.py).
Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/graphs.py).
Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/adversarial_losses.py).
Flags particular to each job are defined in the main binary files.
### Data Generation
* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_data.py)
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/document_generators.py)
control which dataset is processed and how.
## Contact for Issues
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,25 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
'Norm length of adversarial perturbation to be '
'optimized with validation. '
'5.0 is optimal on IMDB with virtual adversarial training.')
# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iterations')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
'Small constant for finite difference method')
# Parameters for building the graph
......@@ -83,19 +85,22 @@ def virtual_adversarial_loss(logits, embedded, inputs,
"""
# Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
logits = tf.stop_gradient(logits)
# Only care about the KL divergence on the final timestep.
weights = inputs.eos_weights
assert weights is not None
# Initialize perturbation with random noise.
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
d = tf.random_normal(shape=tf.shape(embedded))
# Perform finite difference method and power iteration.
# See Eq. (8) in the paper: http://arxiv.org/pdf/1507.00677.pdf.
# Adding small noise to the input and taking the gradient with respect to the
# noise corresponds to one power iteration.
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
......@@ -104,8 +109,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
d = tf.stop_gradient(d)
perturb = _scale_l2(d, FLAGS.perturb_norm_length)
vadv_logits = logits_from_embedding_fn(embedded + perturb)
return _kl_divergence_with_logits(logits, vadv_logits, weights)
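# Illustrative sketch (not part of this change; hypothetical helper): the loop
# above is finite-difference power iteration. For a loss with gradient g and
# Hessian H, it approximates the dominant eigenvector of H without forming H,
# using H d ~= (g(x + xi * d) - g(x)) / xi for a small xi.
def _power_iteration_sketch(grad_fn, x, xi=1e-1, num_iters=1):
  import numpy as np
  d = np.random.randn(*x.shape)
  for _ in range(num_iters):
    d = d / (np.linalg.norm(d) + 1e-12)  # normalize, as _scale_l2 does
    d = (grad_fn(x + xi * d) - grad_fn(x)) / xi  # Hessian-vector product
  return d / (np.linalg.norm(d) + 1e-12)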
......@@ -136,7 +140,8 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
"""Virtual adversarial loss for bidirectional models."""
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = f_inputs.eos_weights
assert weights is not None
perturbs = [
_mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
......@@ -155,10 +160,7 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
perturbs = [tf.stop_gradient(d) for d in perturbs]
perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
vadv_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
return _kl_divergence_with_logits(logits, vadv_logits, weights)
......@@ -167,7 +169,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,
def _mask_by_length(t, length):
"""Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
maxlen = t.get_shape().as_list()[1]
# Subtract 1 from length to prevent the perturbation from being applied to
# the 'eos' token.
mask = tf.sequence_mask(length - 1, maxlen=maxlen)
mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
# shape(mask) = (batch, num_timesteps, 1)
return t * mask
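# Illustrative (hypothetical values): for length=[3] and maxlen=4,
# tf.sequence_mask(length - 1, maxlen=4) yields [[True, True, False, False]],
# so both the padding and the final ('eos') timestep are zeroed out.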
......@@ -175,32 +179,16 @@ def _mask_by_length(t, length):
def _scale_l2(x, norm_length):
# shape(x) = (batch, num_timesteps, d)
# Divide x by max(abs(x)) for a numerically stable L2 norm.
# 2norm(x) = a * 2norm(x/a)
# Scale over the full sequence, dims (1, 2)
alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
l2_norm = alpha * tf.sqrt(
tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
x_unit = x / l2_norm
return norm_length * x_unit
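# Illustrative sketch (not part of this change; hypothetical helper): why the
# max-abs rescaling above is numerically safer. Squaring large float32 values
# overflows, but the identity 2norm(x) = a * 2norm(x/a) sidesteps that.
def _stable_l2_sketch():
  import numpy as np
  x = np.array([3e20, 4e20], dtype=np.float32)
  naive = np.sqrt(np.sum(x * x))  # inf: the squares overflow float32
  a = np.max(np.abs(x))
  stable = a * np.sqrt(np.sum((x / a) ** 2))  # ~5e20, computed safely
  return naive, stable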
def _kl_divergence_with_logits(q_logits, p_logits, weights):
"""Returns weighted KL divergence between distributions q and p.
......@@ -218,21 +206,20 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
# For logistic regression
if FLAGS.num_classes == 2:
q = tf.nn.sigmoid(q_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
kl = tf.squeeze(kl)
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
kl.get_shape().assert_has_rank(1)
weights.get_shape().assert_has_rank(1)
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
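# Illustrative sketch (not part of this change; hypothetical helper): the
# log-softmax form above avoids the log(0) = -inf that q * (log(q) - log(p))
# produces once a softmax probability underflows.
def _log_softmax_sketch():
  import numpy as np
  logits = np.array([0.0, -200.0], dtype=np.float32)
  p = np.exp(logits) / np.sum(np.exp(logits))  # p[1] underflows to 0 in float32
  naive = np.log(p)  # [0., -inf]
  stable = logits - np.log(np.sum(np.exp(logits)))  # log-softmax: [0., -200.]
  return naive, stable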
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [
"//adversarial_text:__subpackages__",
......@@ -10,6 +12,7 @@ py_binary(
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
......@@ -19,17 +22,24 @@ py_binary(
deps = [
":data_utils",
":document_generators",
# tensorflow dep,
],
)
py_library(
name = "document_generators",
srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
)
py_library(
name = "data_utils",
srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
)
py_test(
......@@ -37,5 +47,6 @@ py_test(
srcs = ["data_utils_test.py"],
deps = [
":data_utils",
# tensorflow dep,
],
)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for generating/preprocessing data for adversarial text models."""
import operator
import os
import random
import re
# Dependency imports
import tensorflow as tf
EOS_TOKEN = '</s>'
......@@ -215,13 +217,17 @@ def build_lm_sequence(seq):
Returns:
SequenceWrapper with `seq` tokens copied over to output sequence tokens and
labels (offset by 1, i.e. predict next token) with weights set to 1.0,
except for the <eos> token, which is weighted 0.0.
"""
lm_seq = SequenceWrapper()
for i, timestep in enumerate(seq):
if i == len(seq) - 1:
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i].token).set_weight(0.0)
else:
lm_seq.add_timestep().set_token(timestep.token).set_label(
seq[i + 1].token).set_weight(1.0)
return lm_seq
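# Example of the weighting above (hypothetical tokens): for seq tokens
# [t0, t1, t2] with t2 = <eos>, the emitted (token, label, weight) triples are
# (t0, t1, 1.0), (t1, t2, 1.0), (t2, t2, 0.0), so the model is never trained
# to predict a token past <eos>.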
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from adversarial_text.data import data_utils
......@@ -91,9 +92,16 @@ class DataUtilsTest(tf.test.TestCase):
seq = self._buildDummySequence()
lm_seq = data.build_lm_sequence(seq)
for i, ts in enumerate(lm_seq):
# For the end of the sequence, the token and label should be the same, and
# the weight should be 0.0.
if i == len(lm_seq) - 1:
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, i)
self.assertEqual(ts.weight, 0.0)
else:
self.assertEqual(ts.token, i)
self.assertEqual(ts.label, i + 1)
self.assertEqual(ts.weight, 1.0)
def testBuildSAESeq(self):
seq = self._buildDummySequence()
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input readers and document/token generators for datasets."""
from __future__ import absolute_import
from __future__ import division
......@@ -23,6 +22,8 @@ import csv
import os
import random
# Dependency imports
import tensorflow as tf
from adversarial_text.data import data_utils
......@@ -60,7 +61,6 @@ flags.DEFINE_string('rcv1_input_dir', '',
flags.DEFINE_string('rt_input_dir', '',
'The Rotten Tomatoes dataset input directory.')
flags.DEFINE_string('amazon_unlabeled_input_file', '',
'The unlabeled Amazon Reviews dataset input file. If set, '
......@@ -211,8 +211,12 @@ def imdb_documents(dataset='train',
if FLAGS.amazon_unlabeled_input_file and include_unlabeled:
with open(FLAGS.amazon_unlabeled_input_file) as rt_f:
for content in rt_f:
yield Document(
content=content,
is_validation=False,
is_test=False,
label=None,
add_tokens=False)
def dbpedia_documents(dataset='train',
......@@ -265,7 +269,8 @@ def rcv1_documents(dataset='train',
# pylint:disable=line-too-long
"""Generates Documents for Reuters Corpus (rcv1) dataset.
Dataset described at
http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm
Args:
dataset: str, identifies the csv file within the rcv1 data directory.
......@@ -354,17 +359,25 @@ def rt_documents(dataset='train',
if class_label is None:
# Process Amazon Review data for unlabeled dataset
if content.startswith('review/text'):
yield Document(
content=content,
is_validation=False,
is_test=False,
label=None,
add_tokens=False)
else:
# 10% of the data is randomly held out for the validation set and
# another 10% of it is randomly held out for the test set
random_int = random.randint(1, 10)
is_validation = random_int == 1
is_test = random_int == 2
if (is_test and dataset != 'test') or (is_validation and
not include_validation):
continue
yield Document(
content=content,
is_validation=is_validation,
is_test=is_test,
label=class_label,
add_tokens=True)
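# Illustrative: random.randint(1, 10) is uniform over {1, ..., 10}, so
# is_validation and is_test each hold with probability 1/10, matching the
# 10%/10% validation/test holdout described in the comment above.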
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create TFRecord files of SequenceExample protos from dataset.
Constructs 3 datasets:
......@@ -31,6 +30,8 @@ from __future__ import print_function
import os
import string
# Dependency imports
import tensorflow as tf
from adversarial_text.data import data_utils
......@@ -197,6 +198,7 @@ def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tf.logging.info('Assigning vocabulary ids...')
vocab_ids = make_vocab_ids(
FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt'))
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates vocabulary and term frequency files for datasets."""
from __future__ import absolute_import
from __future__ import division
......@@ -20,6 +19,8 @@ from __future__ import print_function
from collections import defaultdict
# Dependency imports
import tensorflow as tf
from adversarial_text.data import data_utils
......@@ -66,6 +67,7 @@ def fill_vocab_from_doc(doc, vocab_freqs, doc_counts):
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
vocab_freqs = defaultdict(int)
doc_counts = defaultdict(int)
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates text classification model."""
from __future__ import absolute_import
......@@ -22,6 +21,8 @@ from __future__ import print_function
import math
import time
# Dependency imports
import tensorflow as tf
import graphs
......@@ -100,6 +101,7 @@ def run_eval(eval_ops, summary_writer, saver):
def _log_values(sess, value_ops, summary_writer=None):
"""Evaluate, log, and write summaries of the eval metrics in value_ops."""
metric_names, value_ops = zip(*value_ops.items())
values = sess.run(value_ops)
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual adversarial text models."""
from __future__ import absolute_import
from __future__ import division
......@@ -20,6 +19,9 @@ from __future__ import print_function
import csv
import os
# Dependency imports
import tensorflow as tf
import adversarial_losses as adv_lib
......@@ -81,7 +83,8 @@ flags.DEFINE_integer('replicas_to_aggregate', 1,
# Regularization
flags.DEFINE_float('max_grad_norm', 1.0,
'Clip the global gradient norm to this value.')
flags.DEFINE_float('keep_prob_emb', 1.0, 'keep probability on embedding layer. '
'0.5 is optimal on IMDB with virtual adversarial training.')
flags.DEFINE_float('keep_prob_lstm_out', 1.0,
'keep probability on lstm output.')
flags.DEFINE_float('keep_prob_cl_hidden', 1.0,
......@@ -249,8 +252,7 @@ class VatxtModel(object):
eval_ops = {
'accuracy':
tf.contrib.metrics.streaming_accuracy(
layers_lib.predictions(logits), inputs.labels, inputs.weights)
}
with tf.control_dependencies([inputs.save_state(next_state)]):
......@@ -610,7 +612,8 @@ def _inputs(dataset='train', pretrain=False, bidir=False):
state_size=FLAGS.rnn_cell_size,
num_layers=FLAGS.rnn_num_layers,
batch_size=FLAGS.batch_size,
unroll_steps=FLAGS.num_timesteps,
eos_id=FLAGS.vocab_size - 1)
def _get_vocab_freqs():
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graphs."""
from __future__ import absolute_import
from __future__ import division
......@@ -26,6 +25,8 @@ import shutil
import string
import tempfile
# Dependency imports
import tensorflow as tf
import graphs
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input utils for virtual adversarial text classification."""
from __future__ import absolute_import
......@@ -20,6 +19,9 @@ from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from adversarial_text.data import data_utils
......@@ -28,7 +30,12 @@ from adversarial_text.data import data_utils
class VatxtInput(object):
"""Wrapper around NextQueuedSequenceBatch."""
def __init__(self,
batch,
state_name=None,
tokens=None,
num_states=0,
eos_id=None):
"""Construct VatxtInput.
Args:
......@@ -36,6 +43,7 @@ class VatxtInput(object):
state_name: str, name of state to fetch and save.
tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
num_states: int The number of states to store.
eos_id: int, id of the end-of-sequence token.
"""
self._batch = batch
self._state_name = state_name
......@@ -58,6 +66,14 @@ class VatxtInput(object):
l = tf.reshape(l, [-1])
self._labels = l
# eos weights
self._eos_weights = None
if eos_id:
ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
ew = tf.transpose(ew, [1, 0])
ew = tf.reshape(ew, [-1])
self._eos_weights = ew
@property
def tokens(self):
return self._tokens
......@@ -66,6 +82,10 @@ class VatxtInput(object):
def weights(self):
return self._weights
@property
def eos_weights(self):
return self._eos_weights
@property
def labels(self):
return self._labels
......@@ -246,7 +266,8 @@ def inputs(data_dir=None,
state_size=None,
num_layers=0,
batch_size=32,
unroll_steps=100,
eos_id=None):
"""Inputs for text model.
Args:
......@@ -260,7 +281,7 @@ def inputs(data_dir=None,
num_layers: int, the number of LSTM layers.
batch_size: int, batch size.
unroll_steps: int, number of timesteps to unroll for TBTT.
eos_id: int, id of the end-of-sequence token; used for the KL weights in VAT.
Returns:
Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward and
reverse).
......@@ -280,9 +301,15 @@ def inputs(data_dir=None,
state_size, num_layers, unroll_steps,
batch_size)
forward_input = VatxtInput(
forward_batch,
state_name=state_name,
num_states=num_layers,
eos_id=eos_id)
reverse_input = VatxtInput(
reverse_batch,
state_name=state_name_rev,
num_states=num_layers,
eos_id=eos_id)
return forward_input, reverse_input
elif bidir:
......@@ -322,4 +349,5 @@ def inputs(data_dir=None,
unroll_steps,
batch_size,
bidir_input=False)
return VatxtInput(
batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for VatxtModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
K = tf.contrib.keras
......@@ -34,7 +34,7 @@ def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):
subgraph.add(K.layers.Dense(layer_size, activation='relu'))
if keep_prob < 1.:
subgraph.add(K.layers.Dropout(1. - keep_prob))
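# (K.layers.Dropout above takes the fraction of units to *drop*, hence the
# rate 1. - keep_prob rather than keep_prob.)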
subgraph.add(K.layers.Dense(1 if num_classes == 2 else num_classes))
return subgraph
......@@ -76,7 +76,14 @@ class Embedding(K.layers.Layer):
def call(self, x):
embedded = tf.nn.embedding_lookup(self.var, x)
if self.keep_prob < 1.:
shape = embedded.get_shape().as_list()
# Use the same dropout mask at every timestep by specifying noise_shape.
# This slightly improves performance; see https://arxiv.org/abs/1512.05287
# for the theoretical explanation.
embedded = tf.nn.dropout(
embedded, self.keep_prob, noise_shape=(shape[0], 1, shape[2]))
return embedded
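# Illustrative (hypothetical shapes): with batch=2, time=3, dim=4, the
# noise_shape (2, 1, 4) makes tf.nn.dropout draw a single [2, 1, 4] keep-mask
# and broadcast it across all 3 timesteps, so a dropped embedding dimension
# stays dropped for the entire sequence.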
def _normalize(self, emb):
......@@ -153,11 +160,11 @@ class SoftmaxLoss(K.layers.Layer):
self.lin_w = self.add_weight(
shape=(input_shape[-1], self.vocab_size),
name='lm_lin_w',
initializer=K.initializers.glorot_uniform())
self.lin_b = self.add_weight(
shape=(self.vocab_size,),
name='lm_lin_b',
initializer=K.initializers.glorot_uniform())
super(SoftmaxLoss, self).build(input_shape)
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pretrains a recurrent language model.
Computational time:
2 days to train 100000 steps on 1 layer 1024 hidden units LSTM,
256 embeddings, 400 truncated BP, 256 minibatch and on single GPU (Pascal
Titan X, cuDNNv5).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
import graphs
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,17 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains LSTM text classification model.
Model trains with adversarial or virtual adversarial training.
Computational time:
1.8 hours to train 10000 steps without adversarial or virtual adversarial
training, on 1 layer 1024 hidden units LSTM, 256 embeddings, 400 truncated
BP, 64 minibatch and on single GPU (Pascal Titan X, cuDNNv5).
4 hours to train 10000 steps with adversarial or virtual adversarial
training, with above condition.
To initialize embedding and LSTM cell weights from a pretrained model, set
......@@ -32,6 +31,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
import graphs
......