Commit 30aeec75 authored by Toby Boyd, committed by GitHub

Merge pull request #2 from tensorflow/master

Sync to tensorflow-master
parents 68a18b70 78007443
adversarial_crypto/* @dave-andersen
adversarial_text/* @rsepassi
attention_ocr/* @alexgorban
autoencoders/* @snurkabill
cognitive_mapping_and_planning/* @s-gupta
compression/* @nmjohn
differential_privacy/* @panyx0718
domain_adaptation/* @bousmalis @ddohan
im2txt/* @cshallue
inception/* @shlens @vincentvanhoucke
learning_to_remember_rare_events/* @lukaszkaiser @ofirnachum
lfads/* @jazcollins @susillo
lm_1b/* @oriolvinyals @panyx0718
namignizer/* @knathanieltucker
neural_gpu/* @lukaszkaiser
neural_programmer/* @arvind2505
next_frame_prediction/* @panyx0718
object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
pcl_rl/* @ofirnachum
ptn/* @xcyan @arkanath @hellojas @honglaklee
real_nvp/* @laurent-dinh
rebar/* @gjtucker
resnet/* @panyx0718
skip_thoughts/* @cshallue
slim/* @sguada @nathansilberman
street/* @theraysmith
swivel/* @waterson
syntaxnet/* @calberti @andorardo
textsum/* @panyx0718 @peterjliu
transformer/* @daviddao
tutorials/embedding/* @zffchen78 @a-dai
tutorials/image/* @sherrym @shlens
tutorials/rnn/* @lukaszkaiser @ebrevdo
video_prediction/* @cbfinn
# Contributing guidelines

If you have created a model and would like to publish it here, please send us a
pull request. For those just getting started with pull requests, GitHub has a
[howto](https://help.github.com/articles/using-pull-requests/).

The code for any model in this repository is licensed under the Apache License
...

@@ -22,12 +22,15 @@ running TensorFlow 0.12 or earlier, please

- [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a large-scale life-long memory module for use in deep learning.
- [lfads](lfads): sequential variational autoencoder for analyzing neuroscience data.
- [lm_1b](lm_1b): language modeling on the one billion word benchmark.
- [namignizer](namignizer): recognize and generate names.
- [neural_gpu](neural_gpu): highly parallel neural computer.
- [neural_programmer](neural_programmer): neural network augmented with logic and mathematic operations.
- [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
- [object_detection](object_detection): localizing and identifying multiple objects in a single image.
- [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
- [rebar](rebar): low-variance, unbiased gradient estimates for discrete latent variable models.
- [resnet](resnet): deep and wide residual networks.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
- [slim](slim): image classification models in TF-Slim.
...

@@ -118,7 +118,7 @@ class AdversarialCrypto(object):

  def model(self, collection, message, key=None):
    """The model for Alice, Bob, and Eve. If key=None, the first FC layer
    takes only the message as inputs. Otherwise, it uses both the key
    and the message.

    Args:
...

licenses(["notice"])  # Apache 2.0

# Binaries
# ==============================================================================
py_binary(

@@ -5,6 +7,8 @@ py_binary(

    srcs = ["evaluate.py"],
    deps = [
        ":graphs",
        # google3 file dep,
        # tensorflow dep,
    ],
)
@@ -14,6 +18,8 @@ py_binary(

    deps = [
        ":graphs",
        ":train_utils",
        # google3 file dep,
        # tensorflow dep,
    ],
)
@@ -25,6 +31,8 @@ py_binary(

    deps = [
        ":graphs",
        ":train_utils",
        # google3 file dep,
        # tensorflow dep,
    ],
)
@@ -37,18 +45,23 @@ py_library(

        ":adversarial_losses",
        ":inputs",
        ":layers",
        # tensorflow dep,
    ],
)

py_library(
    name = "adversarial_losses",
    srcs = ["adversarial_losses.py"],
    deps = [
        # tensorflow dep,
    ],
)

py_library(
    name = "inputs",
    srcs = ["inputs.py"],
    deps = [
        # tensorflow dep,
        "//adversarial_text/data:data_utils",
    ],
)

@@ -56,11 +69,18 @@ py_library(

py_library(
    name = "layers",
    srcs = ["layers.py"],
    deps = [
        # tensorflow dep,
    ],
)

py_library(
    name = "train_utils",
    srcs = ["train_utils.py"],
    deps = [
        # numpy dep,
        # tensorflow dep,
    ],
)

# Tests

@@ -71,6 +91,7 @@ py_test(

    srcs = ["graphs_test.py"],
    deps = [
        ":graphs",
        # tensorflow dep,
        "//adversarial_text/data:data_utils",
    ],
)
@@ -56,7 +56,6 @@ $ bazel run :pretrain -- \

    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --num_candidate_samples=1024 \
    --optimizer=adam \
    --batch_size=256 \
    --learning_rate=0.001 \
    --learning_rate_decay_factor=0.9999 \

@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \

    --rnn_cell_size=1024 \
    --cl_num_layers=1 \
    --cl_hidden_size=30 \
    --optimizer=adam \
    --batch_size=64 \
    --learning_rate=0.0005 \
    --learning_rate_decay_factor=0.9998 \

@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \

    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings \
    --adv_training_method=vat \
    --perturb_norm_length=5.0
```
### Evaluate on test data

@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in

### Command-Line Flags

Flags related to distributed training and the training loop itself are defined
in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/train_utils.py).

Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/graphs.py).

Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/adversarial_losses.py).

Flags particular to each job are defined in the main binary files.
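All of these modules share the standard `tf.app.flags` pattern for defining and reading flags. As a minimal, self-contained illustration (the flag names below are for demonstration only, not the repository's actual set):

```python
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('rnn_cell_size', 1024, 'Number of hidden units in the LSTM.')
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('rnn_cell_size=%d, learning_rate=%f',
                  FLAGS.rnn_cell_size, FLAGS.learning_rate)


if __name__ == '__main__':
  tf.app.run()
```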
### Data Generation

* Vocabulary generation: [`gen_vocab.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_vocab.py)
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_data.py)

Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/document_generators.py)
control which dataset is processed and how.

## Contact for Issues

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,25 +12,27 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adversarial losses for text models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
                   'Norm length of adversarial perturbation to be '
                   'optimized with validation. '
                   '5.0 is optimal on IMDB with virtual adversarial training.')
# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1, 'The number of power iteration')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                   'Small constant for finite difference method')

# Parameters for building the graph
@@ -83,19 +85,22 @@ def virtual_adversarial_loss(logits, embedded, inputs,

  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
  weights = inputs.eos_weights
  assert weights is not None

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
  # Adding small noise to input and taking gradient with respect to the noise
  # corresponds to 1 power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(

@@ -104,8 +109,7 @@ def virtual_adversarial_loss(logits, embedded, inputs,

        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)
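The loop above is a power iteration: each pass normalizes `d`, perturbs the input by it, and differentiates the resulting KL divergence, which (to second order) multiplies `d` by the Hessian of the KL. A toy NumPy sketch with a stand-in Hessian shows why repeated normalize-and-differentiate steps converge to the most sensitive perturbation direction:

```python
import numpy as np

# For KL(x, x + d) ~ d^T H d locally, the gradient at d is 2 H d, so
# normalize-then-differentiate is power iteration on H. H here is only a
# toy stand-in for the Hessian of the KL divergence.
H = np.array([[3.0, 0.2],
              [0.2, 1.0]])
d = np.random.randn(2)
for _ in range(20):
  d = d / np.linalg.norm(d)   # the _scale_l2 analogue
  d = 2.0 * H.dot(d)          # gradient of d^T H d
d = d / np.linalg.norm(d)
print(d)  # ~ the dominant eigenvector of H: the most sensitive direction
```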
@@ -136,7 +140,8 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,

  """Virtual adversarial loss for bidirectional models."""
  logits = tf.stop_gradient(logits)
  f_inputs, _ = inputs
  weights = f_inputs.eos_weights
  assert weights is not None

  perturbs = [
      _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)

@@ -155,10 +160,7 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,

      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  perturbs = [tf.stop_gradient(d) for d in perturbs]

  perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
  vadv_logits = logits_from_embedding_fn(
      [emb + d for (emb, d) in zip(embedded, perturbs)])
  return _kl_divergence_with_logits(logits, vadv_logits, weights)
@@ -167,7 +169,9 @@ def virtual_adversarial_loss_bidir(logits, embedded, inputs,

def _mask_by_length(t, length):
  """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
  maxlen = t.get_shape().as_list()[1]

  # Subtract 1 from length to prevent the perturbation from going on 'eos'
  mask = tf.sequence_mask(length - 1, maxlen=maxlen)
  mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
  # shape(mask) = (batch, num_timesteps, 1)
  return t * mask
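The `length - 1` change means the mask now zeroes the final real timestep (the eos token) as well as the padding, so the perturbation never touches the eos embedding. A quick NumPy sketch for one example:

```python
import numpy as np

length, maxlen = 4, 6
old_mask = (np.arange(maxlen) < length).astype(np.float32)
new_mask = (np.arange(maxlen) < length - 1).astype(np.float32)
print(old_mask)  # [1. 1. 1. 1. 0. 0.] -> eos (timestep 3) was perturbed
print(new_mask)  # [1. 1. 1. 0. 0. 0.] -> eos and padding both masked
```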
@@ -175,32 +179,16 @@ def _mask_by_length(t, length):

def _scale_l2(x, norm_length):
  # shape(x) = (batch, num_timesteps, d)
  # Divide x by max(abs(x)) for a numerically stable L2 norm.
  # 2norm(x) = a * 2norm(x/a)
  # Scale over the full sequence, dims (1, 2)
  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
  l2_norm = alpha * tf.sqrt(
      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
  x_unit = x / l2_norm
  return norm_length * x_unit
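Dividing by `max(abs(x))` before squaring is what keeps the norm finite for large activations: in float32, squaring values around 1e20 overflows to inf, while the rescaled sum does not. A NumPy sketch of the same computation (assuming float32 inputs):

```python
import numpy as np

def scale_l2(x, norm_length):
  # Mirror of _scale_l2: factor out max|x| so the squares cannot overflow,
  # then rescale every example to the requested L2 norm.
  alpha = np.max(np.abs(x), axis=(1, 2), keepdims=True) + 1e-12
  l2_norm = alpha * np.sqrt(
      np.sum((x / alpha) ** 2, axis=(1, 2), keepdims=True) + 1e-6)
  return norm_length * x / l2_norm

x = (np.random.randn(2, 5, 3) * 1e20).astype(np.float32)
print(np.sum(x ** 2))                                 # inf: naive norm overflows
print(np.linalg.norm(scale_l2(x, 5.0), axis=(1, 2)))  # ~[5. 5.]
```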
def _end_of_seq_mask(tokens):
  """Generate a mask for the EOS token (1.0 on EOS, 0.0 otherwise).

  Args:
    tokens: 1-D integer tensor [num_timesteps*batch_size]. Each element is an
      id from the vocab.

  Returns:
    Float tensor same shape as tokens, whose values are 1.0 on the end of
    sequence and 0.0 on the others.
  """
  eos_id = FLAGS.vocab_size - 1
  return tf.cast(tf.equal(tokens, eos_id), tf.float32)
def _kl_divergence_with_logits(q_logits, p_logits, weights):
  """Returns weighted KL divergence between distributions q and p.

@@ -218,21 +206,20 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):

  # For logistic regression
  if FLAGS.num_classes == 2:
    q = tf.nn.sigmoid(q_logits)
    p = tf.nn.sigmoid(p_logits)
    kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
          tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
    kl = tf.squeeze(kl)

  # For softmax regression
  else:
    q = tf.nn.softmax(q_logits)
    kl = tf.reduce_sum(
        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)

  num_labels = tf.reduce_sum(weights)
  num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

  kl.get_shape().assert_has_rank(1)
  weights.get_shape().assert_has_rank(1)

  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
  return loss
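The rewritten softmax branch computes the same KL divergence as the old `q * (log(q) - log(p))` form but via `log_softmax`, which never takes the log of an underflowed probability. A NumPy check of the identity (a sketch, not the repository's code):

```python
import numpy as np

def log_softmax(z):
  # Stable log-softmax: shift by the row max before exponentiating.
  z = z - z.max(axis=-1, keepdims=True)
  return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

q_logits = np.array([[2.0, 0.5, -1.0]])
p_logits = np.array([[0.5, 0.5, 0.5]])
q = np.exp(log_softmax(q_logits))
p = np.exp(log_softmax(p_logits))

kl_naive = np.sum(q * (np.log(q) - np.log(p)), axis=1)
kl_stable = np.sum(q * (log_softmax(q_logits) - log_softmax(p_logits)), axis=1)
print(np.allclose(kl_naive, kl_stable))  # True

# With extreme logits, exp(log_softmax) underflows to 0 and the naive form
# produces -inf (then nan downstream); the log-softmax form stays finite.
extreme = np.array([[1000.0, 0.0, -1000.0]], dtype=np.float32)
print(np.log(np.exp(log_softmax(extreme))))  # [0., -inf, -inf]
print(log_softmax(extreme))                  # [0., -1000., -2000.]
```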
licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//adversarial_text:__subpackages__", "//adversarial_text:__subpackages__",
...@@ -10,6 +12,7 @@ py_binary( ...@@ -10,6 +12,7 @@ py_binary(
deps = [ deps = [
":data_utils", ":data_utils",
":document_generators", ":document_generators",
# tensorflow dep,
], ],
) )
...@@ -19,17 +22,24 @@ py_binary( ...@@ -19,17 +22,24 @@ py_binary(
deps = [ deps = [
":data_utils", ":data_utils",
":document_generators", ":document_generators",
# tensorflow dep,
], ],
) )
py_library( py_library(
name = "document_generators", name = "document_generators",
srcs = ["document_generators.py"], srcs = ["document_generators.py"],
deps = [
# tensorflow dep,
],
) )
py_library( py_library(
name = "data_utils", name = "data_utils",
srcs = ["data_utils.py"], srcs = ["data_utils.py"],
deps = [
# tensorflow dep,
],
) )
py_test( py_test(
...@@ -37,5 +47,6 @@ py_test( ...@@ -37,5 +47,6 @@ py_test(
srcs = ["data_utils_test.py"], srcs = ["data_utils_test.py"],
deps = [ deps = [
":data_utils", ":data_utils",
# tensorflow dep,
], ],
) )
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,13 +12,15 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for generating/preprocessing data for adversarial text models."""

import operator
import os
import random
import re

# Dependency imports

import tensorflow as tf

EOS_TOKEN = '</s>'
@@ -215,13 +217,17 @@ def build_lm_sequence(seq):

  Returns:
    SequenceWrapper with `seq` tokens copied over to output sequence tokens and
    labels (offset by 1, i.e. predict next token) with weights set to 1.0,
    except for <eos> token.
  """
  lm_seq = SequenceWrapper()
  for i, timestep in enumerate(seq):
    if i == len(seq) - 1:
      lm_seq.add_timestep().set_token(timestep.token).set_label(
          seq[i].token).set_weight(0.0)
    else:
      lm_seq.add_timestep().set_token(timestep.token).set_label(
          seq[i + 1].token).set_weight(1.0)
  return lm_seq
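Concretely, the new loop keeps the eos timestep in the language-model sequence but gives it zero weight instead of dropping it, so sequence lengths stay aligned with the other datasets. For a toy token sequence (a sketch using plain ints rather than SequenceWrapper timesteps):

```python
# Toy tokens [7, 8, 9, 2] where 2 is the eos id. Every step predicts the
# next token with weight 1.0; the final (eos) step labels itself with
# weight 0.0 rather than being dropped as before.
seq = [7, 8, 9, 2]
triples = [(tok, seq[i + 1], 1.0) if i < len(seq) - 1 else (tok, tok, 0.0)
           for i, tok in enumerate(seq)]
print(triples)  # [(7, 8, 1.0), (8, 9, 1.0), (9, 2, 1.0), (2, 2, 0.0)]
```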
...

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,12 +12,13 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for data_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

from adversarial_text.data import data_utils
@@ -91,9 +92,16 @@ class DataUtilsTest(tf.test.TestCase):

    seq = self._buildDummySequence()
    lm_seq = data.build_lm_sequence(seq)
    for i, ts in enumerate(lm_seq):
      # For end of sequence, the token and label should be same, and weight
      # should be 0.0.
      if i == len(lm_seq) - 1:
        self.assertEqual(ts.token, i)
        self.assertEqual(ts.label, i)
        self.assertEqual(ts.weight, 0.0)
      else:
        self.assertEqual(ts.token, i)
        self.assertEqual(ts.label, i + 1)
        self.assertEqual(ts.weight, 1.0)

  def testBuildSAESeq(self):
    seq = self._buildDummySequence()

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Input readers and document/token generators for datasets."""

from __future__ import absolute_import
from __future__ import division

@@ -23,6 +22,8 @@ import csv

import os
import random

# Dependency imports

import tensorflow as tf

from adversarial_text.data import data_utils

@@ -60,7 +61,6 @@ flags.DEFINE_string('rcv1_input_dir', '',

flags.DEFINE_string('rt_input_dir', '',
                    'The Rotten Tomatoes dataset input directory.')

# The amazon reviews input file to use in either the RT or IMDB datasets.
flags.DEFINE_string('amazon_unlabeled_input_file', '',
                    'The unlabeled Amazon Reviews dataset input file. If set, '
@@ -211,8 +211,12 @@ def imdb_documents(dataset='train',

  if FLAGS.amazon_unlabeled_input_file and include_unlabeled:
    with open(FLAGS.amazon_unlabeled_input_file) as rt_f:
      for content in rt_f:
        yield Document(
            content=content,
            is_validation=False,
            is_test=False,
            label=None,
            add_tokens=False)


def dbpedia_documents(dataset='train',
@@ -265,7 +269,8 @@ def rcv1_documents(dataset='train',

  # pylint:disable=line-too-long
  """Generates Documents for Reuters Corpus (rcv1) dataset.

  Dataset described at
  http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm

  Args:
    dataset: str, identifies the csv file within the rcv1 data directory.
@@ -354,17 +359,25 @@ def rt_documents(dataset='train',

      if class_label is None:
        # Process Amazon Review data for unlabeled dataset
        if content.startswith('review/text'):
          yield Document(
              content=content,
              is_validation=False,
              is_test=False,
              label=None,
              add_tokens=False)
      else:
        # 10% of the data is randomly held out for the validation set and
        # another 10% of it is randomly held out for the test set
        random_int = random.randint(1, 10)
        is_validation = random_int == 1
        is_test = random_int == 2
        if (is_test and dataset != 'test') or (is_validation and
                                               not include_validation):
          continue

        yield Document(
            content=content,
            is_validation=is_validation,
            is_test=is_test,
            label=class_label,
            add_tokens=True)
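The 10%/10% holdout above relies on an unseeded `random.randint`, so split membership is re-drawn on every run. A small sketch of the assignment logic in isolation:

```python
import random

def assign_split(rng=random):
  # Mirrors the branch above: 1 -> validation, 2 -> test, 3..10 -> train.
  r = rng.randint(1, 10)
  if r == 1:
    return 'validation'
  if r == 2:
    return 'test'
  return 'train'

counts = {'train': 0, 'validation': 0, 'test': 0}
for _ in range(10000):
  counts[assign_split()] += 1
print(counts)  # roughly {'train': 8000, 'validation': 1000, 'test': 1000}
```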
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create TFRecord files of SequenceExample protos from dataset.

Constructs 3 datasets:

@@ -31,6 +30,8 @@ from __future__ import print_function

import os
import string

# Dependency imports

import tensorflow as tf

from adversarial_text.data import data_utils

@@ -197,6 +198,7 @@ def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Assigning vocabulary ids...')
  vocab_ids = make_vocab_ids(
      FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt'))

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generates vocabulary and term frequency files for datasets."""

from __future__ import absolute_import
from __future__ import division

@@ -20,6 +19,8 @@ from __future__ import print_function

from collections import defaultdict

# Dependency imports

import tensorflow as tf

from adversarial_text.data import data_utils

@@ -66,6 +67,7 @@ def fill_vocab_from_doc(doc, vocab_freqs, doc_counts):


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  vocab_freqs = defaultdict(int)
  doc_counts = defaultdict(int)

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Evaluates text classification model."""

from __future__ import absolute_import

@@ -22,6 +21,8 @@ from __future__ import print_function

import math
import time

# Dependency imports

import tensorflow as tf

import graphs

@@ -100,6 +101,7 @@ def run_eval(eval_ops, summary_writer, saver):


def _log_values(sess, value_ops, summary_writer=None):
  """Evaluate, log, and write summaries of the eval metrics in value_ops."""
  metric_names, value_ops = zip(*value_ops.items())
  values = sess.run(value_ops)

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Virtual adversarial text models."""

from __future__ import absolute_import
from __future__ import division

@@ -20,6 +19,9 @@ from __future__ import print_function

import csv
import os

# Dependency imports

import tensorflow as tf

import adversarial_losses as adv_lib

@@ -81,7 +83,8 @@ flags.DEFINE_integer('replicas_to_aggregate', 1,

# Regularization
flags.DEFINE_float('max_grad_norm', 1.0,
                   'Clip the global gradient norm to this value.')
flags.DEFINE_float('keep_prob_emb', 1.0, 'keep probability on embedding layer. '
                   '0.5 is optimal on IMDB with virtual adversarial training.')
flags.DEFINE_float('keep_prob_lstm_out', 1.0,
                   'keep probability on lstm output.')
flags.DEFINE_float('keep_prob_cl_hidden', 1.0,
@@ -249,8 +252,7 @@ class VatxtModel(object):

    eval_ops = {
        'accuracy':
            tf.contrib.metrics.streaming_accuracy(
                layers_lib.predictions(logits), inputs.labels, inputs.weights)
    }

    with tf.control_dependencies([inputs.save_state(next_state)]):

@@ -610,7 +612,8 @@ def _inputs(dataset='train', pretrain=False, bidir=False):

      state_size=FLAGS.rnn_cell_size,
      num_layers=FLAGS.rnn_num_layers,
      batch_size=FLAGS.batch_size,
      unroll_steps=FLAGS.num_timesteps,
      eos_id=FLAGS.vocab_size - 1)


def _get_vocab_freqs():

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for graphs."""

from __future__ import absolute_import
from __future__ import division

@@ -26,6 +25,8 @@ import shutil

import string
import tempfile

# Dependency imports

import tensorflow as tf

import graphs

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,7 +12,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Input utils for virtual adversarial text classification."""

from __future__ import absolute_import

@@ -20,6 +19,9 @@ from __future__ import division

from __future__ import print_function

import os

# Dependency imports

import tensorflow as tf

from adversarial_text.data import data_utils

@@ -28,7 +30,12 @@ from adversarial_text.data import data_utils

class VatxtInput(object):
  """Wrapper around NextQueuedSequenceBatch."""

  def __init__(self,
               batch,
               state_name=None,
               tokens=None,
               num_states=0,
               eos_id=None):
    """Construct VatxtInput.

    Args:

@@ -36,6 +43,7 @@ class VatxtInput(object):
      state_name: str, name of state to fetch and save.
      tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, the number of states to store.
      eos_id: int, id of the end-of-sequence token.
    """
    self._batch = batch
    self._state_name = state_name

@@ -58,6 +66,14 @@ class VatxtInput(object):

    l = tf.reshape(l, [-1])
    self._labels = l

    # eos weights
    self._eos_weights = None
    if eos_id:
      ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
      ew = tf.transpose(ew, [1, 0])
      ew = tf.reshape(ew, [-1])
      self._eos_weights = ew
  @property
  def tokens(self):
    return self._tokens

@@ -66,6 +82,10 @@ class VatxtInput(object):

  @property
  def weights(self):
    return self._weights

  @property
  def eos_weights(self):
    return self._eos_weights

  @property
  def labels(self):
    return self._labels
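The tokens tensor is time-major here, while labels and weights are flattened batch-major, hence the transpose before the reshape in the eos-weight computation above. A NumPy sketch of the same steps:

```python
import numpy as np

eos_id = 2
tokens = np.array([[7, 5],
                   [8, 2],
                   [2, 0]])  # time-major: (num_timesteps=3, batch=2)
ew = (tokens == eos_id).astype(np.float32)
ew = ew.T.reshape(-1)        # batch-major, flattened like the labels
print(ew)                    # [0. 0. 1. 0. 1. 0.]
```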
@@ -246,7 +266,8 @@ def inputs(data_dir=None,

           state_size=None,
           num_layers=0,
           batch_size=32,
           unroll_steps=100,
           eos_id=None):
  """Inputs for text model.

  Args:
@@ -260,7 +281,7 @@ def inputs(data_dir=None,

    num_layers: int, the number of LSTM layers.
    batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for TBTT.
    eos_id: int, id of the end-of-sequence token; used for the KL weights in
      virtual adversarial training.

  Returns:
    Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward and
    reverse).
@@ -280,9 +301,15 @@ def inputs(data_dir=None,

                                     state_size, num_layers, unroll_steps,
                                     batch_size)
    forward_input = VatxtInput(
        forward_batch,
        state_name=state_name,
        num_states=num_layers,
        eos_id=eos_id)
    reverse_input = VatxtInput(
        reverse_batch,
        state_name=state_name_rev,
        num_states=num_layers,
        eos_id=eos_id)
    return forward_input, reverse_input
  elif bidir:

@@ -322,4 +349,5 @@ def inputs(data_dir=None,

      unroll_steps,
      batch_size,
      bidir_input=False)
  return VatxtInput(
      batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,14 +12,14 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Layers for VatxtModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

K = tf.contrib.keras

@@ -34,7 +34,7 @@ def cl_logits_subgraph(layer_sizes, input_size, num_classes, keep_prob=1.):

    subgraph.add(K.layers.Dense(layer_size, activation='relu'))
    if keep_prob < 1.:
      # K.layers.Dropout takes the fraction to drop, not the keep probability.
      subgraph.add(K.layers.Dropout(1. - keep_prob))
  subgraph.add(K.layers.Dense(1 if num_classes == 2 else num_classes))
  return subgraph
@@ -76,7 +76,14 @@ class Embedding(K.layers.Layer):

  def call(self, x):
    embedded = tf.nn.embedding_lookup(self.var, x)
    if self.keep_prob < 1.:
      shape = embedded.get_shape().as_list()

      # Use the same dropout mask at each timestep by specifying noise_shape.
      # This slightly improves performance.
      # Please see https://arxiv.org/abs/1512.05287 for the theoretical
      # explanation.
      embedded = tf.nn.dropout(
          embedded, self.keep_prob, noise_shape=(shape[0], 1, shape[2]))
    return embedded
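With `noise_shape=(batch, 1, dim)`, one Bernoulli mask per example is broadcast over the time axis, which is the variational-dropout scheme of Gal & Ghahramani (https://arxiv.org/abs/1512.05287). A NumPy sketch of the broadcast:

```python
import numpy as np

keep_prob = 0.5
batch, time, dim = 2, 4, 3
embedded = np.random.randn(batch, time, dim)

# One mask per (example, embedding unit), broadcast over the time axis,
# scaled by 1/keep_prob exactly as tf.nn.dropout does.
mask = (np.random.rand(batch, 1, dim) < keep_prob) / keep_prob
dropped = embedded * mask

# Every timestep of an example drops the same embedding units:
zeroed = (dropped == 0.0)
assert (zeroed[:, 0, :] == zeroed[:, 1, :]).all()
```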
  def _normalize(self, emb):

@@ -153,11 +160,11 @@ class SoftmaxLoss(K.layers.Layer):

    self.lin_w = self.add_weight(
        shape=(input_shape[-1], self.vocab_size),
        name='lm_lin_w',
        initializer=K.initializers.glorot_uniform())
    self.lin_b = self.add_weight(
        shape=(self.vocab_size,),
        name='lm_lin_b',
        initializer=K.initializers.glorot_uniform())
    super(SoftmaxLoss, self).build(input_shape)

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,18 +12,19 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Pretrains a recurrent language model.

Computational time:
  2 days to train 100000 steps on 1 layer 1024 hidden units LSTM,
  256 embeddings, 400 truncated BP, 256 minibatch and on single GPU (Pascal
  Titan X, cuDNNv5).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

import graphs

...
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -12,17 +12,16 @@

# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trains LSTM text classification model.

Model trains with adversarial or virtual adversarial training.

Computational time:
  1.8 hours to train 10000 steps without adversarial or virtual adversarial
  training, on 1 layer 1024 hidden units LSTM, 256 embeddings, 400 truncated
  BP, 64 minibatch and on single GPU (Pascal Titan X, cuDNNv5).

  4 hours to train 10000 steps with adversarial or virtual adversarial
  training, with above condition.

To initialize embedding and LSTM cell weights from a pretrained model, set

@@ -32,6 +31,8 @@ from __future__ import absolute_import

from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

import graphs

...