"examples/hubert/lightning_modules.py" did not exist on "c5bd93b6c274664167a1118e7e14c71920d7fad2"
Commit 27b4acd4 authored by Aman Gupta's avatar Aman Gupta
Browse files

Merge remote-tracking branch 'upstream/master'

parents 5133522f d4e1f97f
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency parsing module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from corpus_processing import minibatching
from model import model_helpers
from model import task_module
class DepparseModule(task_module.SemiSupervisedModule):
def __init__(self, config, task_name, n_classes, inputs, encoder):
super(DepparseModule, self).__init__()
self.task_name = task_name
self.n_classes = n_classes
self.labels = labels = tf.placeholder(tf.float32, [None, None, None],
name=task_name + '_labels')
class PredictionModule(object):
def __init__(self, name, dep_reprs, head_reprs, roll_direction=0):
self.name = name
with tf.variable_scope(name + '/predictions'):
# apply hidden layers to the input representations
arc_dep_hidden = model_helpers.project(
dep_reprs, config.projection_size, 'arc_dep_hidden')
arc_head_hidden = model_helpers.project(
head_reprs, config.projection_size, 'arc_head_hidden')
arc_dep_hidden = tf.nn.relu(arc_dep_hidden)
arc_head_hidden = tf.nn.relu(arc_head_hidden)
arc_head_hidden = tf.nn.dropout(arc_head_hidden, inputs.keep_prob)
arc_dep_hidden = tf.nn.dropout(arc_dep_hidden, inputs.keep_prob)
# bilinear classifier excluding the final dot product
arc_head = tf.layers.dense(
arc_head_hidden, config.depparse_projection_size, name='arc_head')
W = tf.get_variable('shared_W',
shape=[config.projection_size, n_classes,
config.depparse_projection_size])
Wr = tf.get_variable('relation_specific_W',
shape=[config.projection_size,
config.depparse_projection_size])
Wr_proj = tf.tile(tf.expand_dims(Wr, axis=-2), [1, n_classes, 1])
W += Wr_proj
arc_dep = tf.tensordot(arc_dep_hidden, W, axes=[[-1], [0]])
shape = tf.shape(arc_dep)
arc_dep = tf.reshape(arc_dep,
[shape[0], -1, config.depparse_projection_size])
# apply the transformer scaling trick to prevent dot products from
# getting too large (possibly not necessary)
scale = np.power(
config.depparse_projection_size, 0.25).astype('float32')
scale = tf.get_variable('scale', initializer=scale, dtype=tf.float32)
arc_dep /= scale
arc_head /= scale
# compute the scores for each candidate arc
word_scores = tf.matmul(arc_head, arc_dep, transpose_b=True)
root_scores = tf.layers.dense(arc_head, n_classes, name='root_score')
arc_scores = tf.concat([root_scores, word_scores], axis=-1)
# disallow the model from making impossible predictions
mask = inputs.mask
mask_shape = tf.shape(mask)
mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, n_classes])
mask = tf.reshape(mask, [-1, mask_shape[1] * n_classes])
mask = tf.concat([tf.ones((mask_shape[0], 1)),
tf.zeros((mask_shape[0], n_classes - 1)), mask],
axis=1)
mask = tf.tile(tf.expand_dims(mask, 1), [1, mask_shape[1], 1])
arc_scores += (mask - 1) * 100.0
self.logits = arc_scores
self.loss = model_helpers.masked_ce_loss(
self.logits, labels, inputs.mask,
roll_direction=roll_direction)
primary = PredictionModule(
'primary',
[encoder.uni_reprs, encoder.bi_reprs],
[encoder.uni_reprs, encoder.bi_reprs])
ps = [
PredictionModule(
'full',
[encoder.uni_reprs, encoder.bi_reprs],
[encoder.uni_reprs, encoder.bi_reprs]),
PredictionModule('fw_fw', [encoder.uni_fw], [encoder.uni_fw]),
PredictionModule('fw_bw', [encoder.uni_fw], [encoder.uni_bw]),
PredictionModule('bw_fw', [encoder.uni_bw], [encoder.uni_fw]),
PredictionModule('bw_bw', [encoder.uni_bw], [encoder.uni_bw]),
]
self.unsupervised_loss = sum(p.loss for p in ps)
self.supervised_loss = primary.loss
self.probs = tf.nn.softmax(primary.logits)
self.preds = tf.argmax(primary.logits, axis=-1)
def update_feed_dict(self, feed, mb):
if self.task_name in mb.teacher_predictions:
feed[self.labels] = mb.teacher_predictions[self.task_name]
elif mb.task_name != 'unlabeled':
labels = minibatching.build_array(
[[0] + e.labels + [0] for e in mb.examples])
feed[self.labels] = np.eye(
(1 + mb.words.shape[1]) * self.n_classes)[labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dependency parsing evaluation (computes UAS/LAS)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from task_specific.word_level import word_level_scorer
class DepparseScorer(word_level_scorer.WordLevelScorer):
def __init__(self, n_relations, punctuation):
super(DepparseScorer, self).__init__()
self._n_relations = n_relations
self._punctuation = punctuation if punctuation else None
def _get_results(self):
correct_unlabeled, correct_labeled, count = 0, 0, 0
for example, preds in zip(self._examples, self._preds):
for w, y_true, y_pred in zip(example.words[1:-1], example.labels, preds):
if w in self._punctuation:
continue
count += 1
correct_labeled += (1 if y_pred == y_true else 0)
correct_unlabeled += (1 if int(y_pred // self._n_relations) ==
int(y_true // self._n_relations) else 0)
return [
("las", 100.0 * correct_labeled / count),
("uas", 100.0 * correct_unlabeled / count),
("loss", self.get_loss()),
]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence tagging module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from corpus_processing import minibatching
from model import model_helpers
from model import task_module
class TaggingModule(task_module.SemiSupervisedModule):
def __init__(self, config, task_name, n_classes, inputs,
encoder):
super(TaggingModule, self).__init__()
self.task_name = task_name
self.n_classes = n_classes
self.labels = labels = tf.placeholder(tf.float32, [None, None, None],
name=task_name + '_labels')
class PredictionModule(object):
def __init__(self, name, input_reprs, roll_direction=0, activate=True):
self.name = name
with tf.variable_scope(name + '/predictions'):
projected = model_helpers.project(input_reprs, config.projection_size)
if activate:
projected = tf.nn.relu(projected)
self.logits = tf.layers.dense(projected, n_classes, name='predict')
targets = labels
targets *= (1 - inputs.label_smoothing)
targets += inputs.label_smoothing / n_classes
self.loss = model_helpers.masked_ce_loss(
self.logits, targets, inputs.mask, roll_direction=roll_direction)
primary = PredictionModule('primary',
([encoder.uni_reprs, encoder.bi_reprs]))
ps = [
PredictionModule('full', ([encoder.uni_reprs, encoder.bi_reprs]),
activate=False),
PredictionModule('forwards', [encoder.uni_fw]),
PredictionModule('backwards', [encoder.uni_bw]),
PredictionModule('future', [encoder.uni_fw], roll_direction=1),
PredictionModule('past', [encoder.uni_bw], roll_direction=-1),
]
self.unsupervised_loss = sum(p.loss for p in ps)
self.supervised_loss = primary.loss
self.probs = tf.nn.softmax(primary.logits)
self.preds = tf.argmax(primary.logits, axis=-1)
def update_feed_dict(self, feed, mb):
if self.task_name in mb.teacher_predictions:
feed[self.labels] = mb.teacher_predictions[self.task_name]
elif mb.task_name != 'unlabeled':
labels = minibatching.build_array(
[[0] + e.labels + [0] for e in mb.examples])
feed[self.labels] = np.eye(self.n_classes)[labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence tagging evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from task_specific.word_level import tagging_utils
from task_specific.word_level import word_level_scorer
class AccuracyScorer(word_level_scorer.WordLevelScorer):
def __init__(self, auto_fail_label=None):
super(AccuracyScorer, self).__init__()
self._auto_fail_label = auto_fail_label
def _get_results(self):
correct, count = 0, 0
for example, preds in zip(self._examples, self._preds):
for y_true, y_pred in zip(example.labels, preds):
count += 1
correct += (1 if y_pred == y_true and y_true != self._auto_fail_label
else 0)
return [
("accuracy", 100.0 * correct / count),
("loss", self.get_loss())
]
class F1Scorer(word_level_scorer.WordLevelScorer):
__metaclass__ = abc.ABCMeta
def __init__(self):
super(F1Scorer, self).__init__()
self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
def _get_results(self):
if self._n_correct == 0:
p, r, f1 = 0, 0, 0
else:
p = 100.0 * self._n_correct / self._n_predicted
r = 100.0 * self._n_correct / self._n_gold
f1 = 2 * p * r / (p + r)
return [
("precision", p),
("recall", r),
("f1", f1),
("loss", self.get_loss()),
]
class EntityLevelF1Scorer(F1Scorer):
def __init__(self, label_mapping):
super(EntityLevelF1Scorer, self).__init__()
self._inv_label_mapping = {v: k for k, v in label_mapping.iteritems()}
def _get_results(self):
self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
for example, preds in zip(self._examples, self._preds):
sent_spans = set(tagging_utils.get_span_labels(
example.labels, self._inv_label_mapping))
span_preds = set(tagging_utils.get_span_labels(
preds, self._inv_label_mapping))
self._n_correct += len(sent_spans & span_preds)
self._n_gold += len(sent_spans)
self._n_predicted += len(span_preds)
return super(EntityLevelF1Scorer, self)._get_results()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for sequence tagging tasks for entity-level tasks (e.g., NER)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def get_span_labels(sentence_tags, inv_label_mapping=None):
"""Go from token-level labels to list of entities (start, end, class)."""
if inv_label_mapping:
sentence_tags = [inv_label_mapping[i] for i in sentence_tags]
span_labels = []
last = 'O'
start = -1
for i, tag in enumerate(sentence_tags):
pos, _ = (None, 'O') if tag == 'O' else tag.split('-')
if (pos == 'S' or pos == 'B' or tag == 'O') and last != 'O':
span_labels.append((start, i - 1, last.split('-')[-1]))
if pos == 'B' or pos == 'S' or last == 'O':
start = i
last = tag
if sentence_tags[-1] != 'O':
span_labels.append((start, len(sentence_tags) - 1,
sentence_tags[-1].split('-')[-1]))
return span_labels
def get_tags(span_labels, length, encoding):
"""Converts a list of entities to token-label labels based on the provided
encoding (e.g., BIOES).
"""
tags = ['O' for _ in range(length)]
for s, e, t in span_labels:
for i in range(s, e + 1):
tags[i] = 'I-' + t
if 'E' in encoding:
tags[e] = 'E-' + t
if 'B' in encoding:
tags[s] = 'B-' + t
if 'S' in encoding and s == e:
tags[s] = 'S-' + t
return tags
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for processing word-level datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import random
import tensorflow as tf
from base import embeddings
from base import utils
from corpus_processing import example
from corpus_processing import minibatching
from task_specific.word_level import tagging_utils
class TaggedDataLoader(object):
def __init__(self, config, name, is_token_level):
self._config = config
self._task_name = name
self._raw_data_path = os.path.join(config.raw_data_topdir, name)
self._is_token_level = is_token_level
self.label_mapping_path = os.path.join(
config.preprocessed_data_topdir,
(name if is_token_level else
name + '_' + config.label_encoding) + '_label_mapping.pkl')
if self.label_mapping:
self._n_classes = len(set(self.label_mapping.values()))
else:
self._n_classes = None
def get_dataset(self, split):
if (split == 'train' and not self._config.for_preprocessing and
tf.gfile.Exists(os.path.join(self._raw_data_path, 'train_subset.txt'))):
split = 'train_subset'
return minibatching.Dataset(
self._config, self._get_examples(split), self._task_name)
def get_labeled_sentences(self, split):
sentences = []
path = os.path.join(self._raw_data_path, split + '.txt')
if not tf.gfile.Exists(path):
if self._config.for_preprocessing:
return []
else:
raise ValueError('Unable to load data from', path)
with tf.gfile.GFile(path, 'r') as f:
sentence = []
for line in f:
line = line.strip().split()
if not line:
if sentence:
words, tags = zip(*sentence)
sentences.append((words, tags))
sentence = []
continue
if line[0] == '-DOCSTART-':
continue
word, tag = line[0], line[-1]
sentence.append((word, tag))
return sentences
@property
def label_mapping(self):
if not self._config.for_preprocessing:
return utils.load_cpickle(self.label_mapping_path)
tag_counts = collections.Counter()
train_tags = set()
for split in ['train', 'dev', 'test']:
for words, tags in self.get_labeled_sentences(split):
if not self._is_token_level:
span_labels = tagging_utils.get_span_labels(tags)
tags = tagging_utils.get_tags(
span_labels, len(words), self._config.label_encoding)
for tag in tags:
if self._task_name == 'depparse':
tag = tag.split('-')[1]
tag_counts[tag] += 1
if split == 'train':
train_tags.add(tag)
if self._task_name == 'ccg':
# for CCG, there are tags in the test sets that aren't in the train set
# all tags not in the train set get mapped to a special label
# the model will never predict this label because it never sees it in the
# training set
not_in_train_tags = []
for tag, count in tag_counts.items():
if tag not in train_tags:
not_in_train_tags.append(tag)
label_mapping = {
label: i for i, label in enumerate(sorted(filter(
lambda t: t not in not_in_train_tags, tag_counts.keys())))
}
n = len(label_mapping)
for tag in not_in_train_tags:
label_mapping[tag] = n
else:
labels = sorted(tag_counts.keys())
if self._task_name == 'depparse':
labels.remove('root')
labels.insert(0, 'root')
label_mapping = {label: i for i, label in enumerate(labels)}
return label_mapping
def _get_examples(self, split):
word_vocab = embeddings.get_word_vocab(self._config)
char_vocab = embeddings.get_char_vocab()
examples = [
TaggingExample(
self._config, self._is_token_level, words, tags,
word_vocab, char_vocab, self.label_mapping, self._task_name)
for words, tags in self.get_labeled_sentences(split)]
if self._config.train_set_percent < 100:
utils.log('using reduced train set ({:}%)'.format(
self._config.train_set_percent))
random.shuffle(examples)
examples = examples[:int(len(examples) *
self._config.train_set_percent / 100.0)]
return examples
class TaggingExample(example.Example):
def __init__(self, config, is_token_level, words, original_tags,
word_vocab, char_vocab, label_mapping, task_name):
super(TaggingExample, self).__init__(words, word_vocab, char_vocab)
if is_token_level:
labels = original_tags
else:
span_labels = tagging_utils.get_span_labels(original_tags)
labels = tagging_utils.get_tags(
span_labels, len(words), config.label_encoding)
if task_name == 'depparse':
self.labels = []
for l in labels:
split = l.split('-')
self.labels.append(
len(label_mapping) * (0 if split[0] == '0' else 1 + int(split[0]))
+ label_mapping[split[1]])
else:
self.labels = [label_mapping[l] for l in labels]
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for word-level scorers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from corpus_processing import scorer
class WordLevelScorer(scorer.Scorer):
__metaclass__ = abc.ABCMeta
def __init__(self):
super(WordLevelScorer, self).__init__()
self._total_loss = 0
self._total_words = 0
self._examples = []
self._preds = []
def update(self, examples, predictions, loss):
super(WordLevelScorer, self).update(examples, predictions, loss)
n_words = 0
for example, preds in zip(examples, predictions):
self._examples.append(example)
self._preds.append(list(preds)[1:len(example.words) - 1])
n_words += len(example.words) - 2
self._total_loss += loss * n_words
self._total_words += n_words
def get_loss(self):
return self._total_loss / max(1, self._total_words)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs training for CVT text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import bisect
import time
import numpy as np
import tensorflow as tf
from base import utils
from model import multitask_model
from task_specific import task_definitions
class Trainer(object):
def __init__(self, config):
self._config = config
self.tasks = [task_definitions.get_task(self._config, task_name)
for task_name in self._config.task_names]
utils.log('Loading Pretrained Embeddings')
pretrained_embeddings = utils.load_cpickle(self._config.word_embeddings)
utils.log('Building Model')
self._model = multitask_model.Model(
self._config, pretrained_embeddings, self.tasks)
utils.log()
def train(self, sess, progress, summary_writer):
heading = lambda s: utils.heading(s, '(' + self._config.model_name + ')')
trained_on_sentences = 0
start_time = time.time()
unsupervised_loss_total, unsupervised_loss_count = 0, 0
supervised_loss_total, supervised_loss_count = 0, 0
for mb in self._get_training_mbs(progress.unlabeled_data_reader):
if mb.task_name != 'unlabeled':
loss = self._model.train_labeled(sess, mb)
supervised_loss_total += loss
supervised_loss_count += 1
if mb.task_name == 'unlabeled':
self._model.run_teacher(sess, mb)
loss = self._model.train_unlabeled(sess, mb)
unsupervised_loss_total += loss
unsupervised_loss_count += 1
mb.teacher_predictions.clear()
trained_on_sentences += mb.size
global_step = self._model.get_global_step(sess)
if global_step % self._config.print_every == 0:
utils.log('step {:} - '
'supervised loss: {:.2f} - '
'unsupervised loss: {:.2f} - '
'{:.1f} sentences per second'.format(
global_step,
supervised_loss_total / max(1, supervised_loss_count),
unsupervised_loss_total / max(1, unsupervised_loss_count),
trained_on_sentences / (time.time() - start_time)))
unsupervised_loss_total, unsupervised_loss_count = 0, 0
supervised_loss_total, supervised_loss_count = 0, 0
if global_step % self._config.eval_dev_every == 0:
heading('EVAL ON DEV')
self.evaluate_all_tasks(sess, summary_writer, progress.history)
progress.save_if_best_dev_model(sess, global_step)
utils.log()
if global_step % self._config.eval_train_every == 0:
heading('EVAL ON TRAIN')
self.evaluate_all_tasks(sess, summary_writer, progress.history, True)
utils.log()
if global_step % self._config.save_model_every == 0:
heading('CHECKPOINTING MODEL')
progress.write(sess, global_step)
utils.log()
def evaluate_all_tasks(self, sess, summary_writer, history, train_set=False):
for task in self.tasks:
results = self._evaluate_task(sess, task, summary_writer, train_set)
if history is not None:
results.append(('step', self._model.get_global_step(sess)))
history.append(results)
if history is not None:
utils.write_cpickle(history, self._config.history_file)
def _evaluate_task(self, sess, task, summary_writer, train_set):
scorer = task.get_scorer()
data = task.train_set if train_set else task.val_set
for i, mb in enumerate(data.get_minibatches(self._config.test_batch_size)):
loss, batch_preds = self._model.test(sess, mb)
scorer.update(mb.examples, batch_preds, loss)
results = scorer.get_results(task.name +
('_train_' if train_set else '_dev_'))
utils.log(task.name.upper() + ': ' + scorer.results_str())
write_summary(summary_writer, results,
global_step=self._model.get_global_step(sess))
return results
def _get_training_mbs(self, unlabeled_data_reader):
datasets = [task.train_set for task in self.tasks]
weights = [np.sqrt(dataset.size) for dataset in datasets]
thresholds = np.cumsum([w / np.sum(weights) for w in weights])
labeled_mbs = [dataset.endless_minibatches(self._config.train_batch_size)
for dataset in datasets]
unlabeled_mbs = unlabeled_data_reader.endless_minibatches()
while True:
dataset_ind = bisect.bisect(thresholds, np.random.random())
yield next(labeled_mbs[dataset_ind])
if self._config.is_semisup:
yield next(unlabeled_mbs)
def write_summary(writer, results, global_step):
for k, v in results:
if 'f1' in k or 'acc' in k or 'loss' in k:
writer.add_summary(tf.Summary(
value=[tf.Summary.Value(tag=k, simple_value=v)]), global_step)
writer.flush()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tracks and saves training progress (models and other data such as the current
location in the lm1b corpus) for later reloading.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from base import utils
from corpus_processing import unlabeled_data
class TrainingProgress(object):
def __init__(self, config, sess, checkpoint_saver, best_model_saver,
restore_if_possible=True):
self.config = config
self.checkpoint_saver = checkpoint_saver
self.best_model_saver = best_model_saver
tf.gfile.MakeDirs(config.checkpoints_dir)
if restore_if_possible and tf.gfile.Exists(config.progress):
history, current_file, current_line = utils.load_cpickle(
config.progress, memoized=False)
self.history = history
self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
config, current_file, current_line)
utils.log("Continuing from global step", dict(self.history[-1])["step"],
"(lm1b file {:}, line {:})".format(current_file, current_line))
self.checkpoint_saver.restore(sess, tf.train.latest_checkpoint(
self.config.checkpoints_dir))
else:
utils.log("No previous checkpoint found - starting from scratch")
self.history = []
self.unlabeled_data_reader = (
unlabeled_data.UnlabeledDataReader(config))
def write(self, sess, global_step):
self.checkpoint_saver.save(sess, self.config.checkpoint,
global_step=global_step)
utils.write_cpickle(
(self.history, self.unlabeled_data_reader.current_file,
self.unlabeled_data_reader.current_line),
self.config.progress)
def save_if_best_dev_model(self, sess, global_step):
best_avg_score = 0
for i, results in enumerate(self.history):
if any("train" in metric for metric, value in results):
continue
total, count = 0, 0
for metric, value in results:
if "f1" in metric or "las" in metric or "accuracy" in metric:
total += value
count += 1
avg_score = total / count
if avg_score >= best_avg_score:
best_avg_score = avg_score
if i == len(self.history) - 1:
utils.log("New best model! Saving...")
self.best_model_saver.save(sess, self.config.best_model_checkpoint,
global_step=global_step)
*.pkl binary
*.tfrecord binary
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
......@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
Additionally it contains an implementation of the variational recurrent neural network (VRNN), a sequential latent variable model that can be trained using these three objectives. This repo provides code for training a VRNN to do sequence modeling of pianoroll and speech data.
Additionally it contains several sequential latent variable model implementations:
* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)
The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.
#### Directory Structure
The important parts of the code are organized as follows.
```
fivo.py # main script, contains flag definitions
runners.py # graph construction code for training and evaluation
bounds.py # code for computing each bound
data
├── datasets.py # readers for pianoroll and speech datasets
├── calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
└── create_timit_dataset.py # preprocesses the TIMIT dataset
models
└── vrnn.py # variational RNN implementation
run_fivo.py # main script, contains flag definitions
fivo
├─smc.py # a sequential Monte Carlo implementation
├─bounds.py # code for computing each bound, uses smc.py
├─runners.py # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py # code for GHMM training and evaluation
├─data
| ├─datasets.py # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
| └─create_timit_dataset.py # preprocesses the TIMIT dataset
└─models
├─base.py # base classes used in other models
├─vrnn.py # VRNN implementation
├─srnn.py # SRNN implementation
└─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
bin
├── run_train.sh # an example script that runs training
├── run_eval.sh # an example script that runs evaluation
└── download_pianorolls.sh # a script that downloads the pianoroll files
├─run_train.sh # an example script that runs training
├─run_eval.sh # an example script that runs evaluation
├─run_sample.sh # an example script that runs sampling
├─run_tests.sh # a script that runs all tests
└─download_pianorolls.sh # a script that downloads pianoroll files
```
### Training on Pianorolls
### Pianorolls
Requirements before we start:
......@@ -60,9 +74,9 @@ python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
#### Training
Now we can train a model. Here is a standard training run, taken from `bin/run_train.sh`:
Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
```
python fivo.py \
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
......@@ -75,26 +89,24 @@ python fivo.py \
--dataset_type="pianoroll"
```
You should see output that looks something like this (with a lot of extra logging cruft):
You should see output that looks something like this (with extra logging cruft):
```
Step 1, fivo bound per timestep: -11.801050
global_step/sec: 9.89825
Step 101, fivo bound per timestep: -11.198309
global_step/sec: 9.55475
Step 201, fivo bound per timestep: -11.287262
global_step/sec: 9.68146
step 301, fivo bound per timestep: -11.316490
global_step/sec: 9.94295
Step 401, fivo bound per timestep: -11.151743
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. This is not an error and is fine.
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python fivo.py \
python run_fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
......@@ -108,12 +120,52 @@ python fivo.py \
You should see output like this:
```
Model restored from step 1, evaluating.
test elbo ll/t: -12.299635, iwae ll/t: -12.128336 fivo ll/t: -11.656939
test elbo ll/seq: -754.750312, iwae ll/seq: -744.238773 fivo ll/seq: -715.3121490
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
#### Sampling
You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
```
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
```
Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
You should see very little output.
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```
Loading the samples with `np.load` confirms that we conditioned the model on 4
prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
......@@ -137,7 +189,7 @@ train mean: 0.006060 train std: 548.136169
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python fivo.py \
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
......@@ -149,6 +201,10 @@ python fivo.py \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
Evaluation and sampling are similar.
### Tests
This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.
### Contact
......
......@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
python run_fivo.py \
--mode=eval \
--logdir=/tmp/fivo \
--model=vrnn \
......
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of sampling from the model.
PIANOROLL_DIR=$HOME/pianorolls
python run_fivo.py \
--mode=sample \
--alsologtostderr \
--logdir="/tmp/fivo" \
--model=vrnn \
--bound=fivo \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll" \
--prefix_length=25 \
--sample_length=50
#!/bin/bash
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
python -m fivo.smc_test && \
python -m fivo.bounds_test && \
python -m fivo.nested_utils_test && \
python -m fivo.data.datasets_test && \
python -m fivo.models.ghmm_test && \
python -m fivo.models.vrnn_test && \
python -m fivo.models.srnn_test && \
python -m fivo.ghmm_runners_test && \
python -m fivo.runners_test
......@@ -18,7 +18,7 @@
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
python run_fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
......
An experimental codebase for running simple examples.
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import tensorflow as tf
import summary_utils as summ
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
summarize=False):
"""Compute the IWAE evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
num_samples: The number of samples to use to compute the IWAE bound.
Returns:
log_p_hat: The IWAE estimator of the lower bound on the log marginal.
loss: A tensor that you can perform gradient descent on to optimize the
bound.
maintain_ema_op: A no-op included for compatibility with FIVO.
states: The sequence of states sampled.
"""
# Initialization
num_instances = tf.shape(observation)[0]
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
log_weights = []
log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
states[-1], observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
if summarize:
weight_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weight_acc, perm=[1, 0]),
allow_nan_stats=False)
weight_entropy = weight_dist.entropy()
weight_entropy = tf.reduce_mean(weight_entropy)
tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
log_weights.append(log_weight_acc)
# Compute the lower bound on the log evidence.
log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
loss = -tf.reduce_mean(log_p_hat)
losses = [Loss("log_p_hat", loss)]
# we clip off the initial state before returning.
# there are no emas for iwae, so we return a noop for that
return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
"""Resample states with multinomial resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
resampling_parameters = tf.transpose(log_weights, perm=[1,0])
resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
ancestors = tf.stop_gradient(
resampling_dist.sample(sample_shape=n))
log_probs = resampling_dist.log_prob(ancestors)
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def stratified_resampling(log_weights, states, n, b):
"""Resample states with straitified resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs = resampling_parameters,
allow_nan_stats=False)
ancestors = tf.stop_gradient(
resampling_dist.sample())
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def systematic_resampling(log_weights, states, n, b):
"""Resample states with systematic resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs=resampling_parameters,
allow_nan_stats=True)
U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1))
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def log_blend(inputs, weights):
"""Blends state in the log space.
Args:
inputs: A set of scalar states, one for each particle in each particle filter.
Should be [num_samples, batch_size].
weights: A set of weights used to blend the state. Each set of weights
should be of dimension [num_samples] (one weight for each previous particle).
There should be one set of weights for each new particle in each particle filter.
Thus the shape should be [num_samples, batch_size, num_samples] where
the first axis indexes new particle and the last axis indexes old particles.
Returns:
blended: The blended states, a tensor of shape [num_samples, batch_size].
"""
raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
my_max = tf.stop_gradient(
tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
)
# Don't ask.
blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
return blended
def relaxed_resampling(log_weights, states, num_samples, batch_size,
log_r_x=None, blend_type="log", temperature=0.5,
straight_through=False):
"""Resample states with relaxed resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
state_dim = states[0].get_shape().as_list()[-1]
# weights are num_samples by batch_size, so we transpose to get a
# set of batch_size distributions over [0,num_samples).
resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
temperature,
logits=resampling_parameters)
# sample num_samples samples from the distribution, resulting in a
# [num_samples, batch_size, num_samples] Tensor that represents a set of
# [num_samples, batch_size] blending weights. The dimensions represent
# [sample index, batch index, blending weight index]
ancestors = resampling_dist.sample(sample_shape=num_samples)
if straight_through:
# Forward pass discrete choices, backwards pass soft choices
hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
dtype=ancestors.dtype)
ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
log_probs = resampling_dist.log_prob(ancestors)
if log_r_x is not None and blend_type == "log":
log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
log_r_x = log_blend(log_r_x, ancestors)
log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
elif log_r_x is not None and blend_type == "linear":
# If blend type is linear just add log_r to the states that will be blended
# linearly.
states.append(log_r_x)
# transpose the 'indices' to be [batch_index, blending weight index, sample index]
ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
resampled_states = []
for state in states:
# state is currently [num_samples * batch_size, state_dim] so we reshape
# to [num_samples, batch_size, state_dim] and then transpose to
# [batch_size, state_size, num_samples]
state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
# state is now (batch_size, state_size, num_samples)
# and ancestor is (batch index, blending weight index, sample index)
# multiplying these gives a matrix of size [batch_size, state_size, num_samples]
next_state = tf.matmul(state, ancestor_inds)
# transpose the state to be [num_samples, batch_size, state_size]
# and then reshape it to match the state format.
next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
resampled_states.append(next_state)
new_dist = tf.contrib.distributions.Categorical(
logits=resampling_parameters)
if log_r_x is not None and blend_type == "linear":
# If blend type is linear pop off log_r that we added to the states.
log_r_x = tf.squeeze(resampled_states[-1])
resampled_states = resampled_states[:-1]
return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
def fivo(model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
use_resampling_grads=True,
resampling_type="multinomial",
resampling_temperature=0.5,
aux=True,
summarize=False):
"""Compute the FIVO evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
resampling_schedule: A list of booleans of length num_timesteps, contains
True if a resampling should occur on a specific timestep.
num_samples: The number of samples to use to compute the IWAE bound.
use_resampling_grads: Whether or not to include the resampling gradients
in loss.
resampling type: The type of resampling, one of "multinomial", "stratified",
"relaxed-logblend", "relaxed-linearblend", "relaxed-stateblend", or
"systematic".
resampling_temperature: A positive temperature only used for relaxed
resampling.
aux: If true, compute the FIVO-AUX bound.
Returns:
log_p_hat: The IWAE estimator of the lower bound on the log marginal.
loss: A tensor that you can perform gradient descent on to optimize the
bound.
maintain_ema_op: An op to update the baseline ema used for the resampling
gradients.
states: The sequence of states sampled.
"""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
log_weights = []
log_weights_all = []
log_p_hats = []
resampling_log_probs = []
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
prev_state, observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
if aux:
if t == num_timesteps - 1:
log_weight -= prev_log_r_zt
else:
log_weight += log_r_zt - prev_log_r_zt
prev_log_r_zt = log_r_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
log_weights_all.append(log_weight_acc)
if resampling_schedule[t]:
# These objects will be resampled
to_resample = [states[-1]]
if aux and "relaxed" not in resampling_type:
to_resample.append(prev_log_r_zt)
# do the resampling
if resampling_type == "multinomial":
(resampled,
resampling_log_prob,
_, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "stratified":
(resampled,
resampling_log_prob,
_, _, _) = stratified_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "systematic":
(resampled,
resampling_log_prob,
_, _, _) = systematic_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif "relaxed" in resampling_type:
if aux:
if resampling_type == "relaxed-logblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="log")
elif resampling_type == "relaxed-linearblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="linear")
elif resampling_type == "relaxed-stateblend":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
elif resampling_type == "relaxed-stateblend-st":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
straight_through=True)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
else:
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
#if summarize:
# resampling_entropy = resampling_dist.entropy()
# resampling_entropy = tf.reduce_mean(resampling_entropy)
# tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
prev_state = resampled[0]
if aux and "relaxed" not in resampling_type:
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt should always be [num_instances]
prev_log_r_zt = tf.squeeze(resampled[1])
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = states[-1]
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
if use_resampling_grads and any(resampling_schedule):
# compute the rewards
# cumsum([a, b, c]) => [a, a+b, a+b+c]
# learning signal at timestep t is
# [sum from i=t+1 to T of log_p_hat_i for t=1:T]
# so we will compute (sum from i=1 to T of log_p_hat_i)
# and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
# rewards is a [num_resampling_events, batch_size] Tensor
rewards = tf.stop_gradient(
tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
# compute ema baseline.
# centered_rewards is [num_resampling_events, batch_size]
baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
centered_rewards = rewards - baseline
if summarize:
summ.summarize_learning_signal(rewards, "rewards")
summ.summarize_learning_signal(centered_rewards, "centered_rewards")
# compute the loss tensor.
resampling_grads = tf.reduce_sum(
tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
else:
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
maintain_baseline_op = tf.no_op()
log_p_hat /= num_timesteps
# we clip off the initial state before returning.
return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
def fivo_aux_td(
model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
summarize=False):
"""Compute the FIVO_AUX evidence lower bound."""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
# must be pre-resampling
log_rs = []
# must be post-resampling
r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
log_r_tildes = []
log_p_xs = []
# contains the weight at each timestep before resampling only on resampling timesteps
log_weights = []
# contains weight at each timestep before resampling
log_weights_all = []
log_p_hats = []
for t in xrange(num_timesteps):
# run the model for one timestep
# zt is state, [num_instances, state_dim]
# log_q_zt, log_p_x_given_z is [num_instances]
# r_tilde_mu, r_tilde_sigma is [num_instances, state_dim]
# p_ztplus1 is a normal distribution on [num_instances, state_dim]
(zt, log_q_zt, log_p_zt, log_p_x_given_z,
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
# Compute the log weight without log r.
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
# Compute log r.
if t == num_timesteps - 1:
log_r = tf.zeros_like(prev_log_r)
else:
p_mu = p_ztplus1.mean()
p_sigma_sq = p_ztplus1.variance()
log_r = (tf.log(r_tilde_sigma_sq) -
tf.log(r_tilde_sigma_sq + p_sigma_sq) -
tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
#log_weight += tf.stop_gradient(log_r - prev_log_r)
log_weight += log_r - prev_log_r
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
# Update accumulators
states.append(zt)
log_weights_all.append(log_weight_acc)
log_p_xs.append(log_p_x_given_z)
log_rs.append(log_r)
# Compute log_r_tilde as [num_instances] Tensor.
prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
prev_log_r_tilde = -0.5*tf.reduce_sum(
tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
#tf.square(tf.stop_gradient(zt) - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
#tf.square(zt - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
log_r_tildes.append(prev_log_r_tilde)
# optionally resample
if resampling_schedule[t]:
# These objects will be resampled
if t < num_timesteps - 1:
to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
else:
to_resample = [zt, log_r]
(resampled,
_, _, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
prev_state = resampled[0]
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt and log_r_tilde should always be [num_instances]
prev_log_r = tf.squeeze(resampled[1])
if t < num_timesteps -1:
r_tilde_params.append((resampled[2], resampled[3]))
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = zt
prev_log_r = log_r
if t < num_timesteps - 1:
r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
# Compute the bellman loss.
# Will remove the first timestep as it is not used.
# log p(x_t|z_t) is in row t-1.
log_p_x = tf.reshape(tf.stack(log_p_xs),
[num_timesteps, num_samples, batch_size])
# log r_t is contained in row t-1.
# last column is zeros (because at timestep T (num_timesteps) r is 1.
log_r = tf.reshape(tf.stack(log_rs),
[num_timesteps, num_samples, batch_size])
# [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
[num_timesteps, num_samples, batch_size])
log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
keepdims=True)
bellman_sos = tf.reduce_mean(tf.square(
log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
tf.summary.scalar("bellman_loss", bellman_loss)
if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
log_p_hat_collection = list(set(tf.trainable_variables()) -
set(tf.get_collection("R_TILDE_VARS")))
for v in log_p_hat_collection:
tf.add_to_collection("LOG_P_HAT_VARS", v)
log_p_hat /= num_timesteps
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import models
def make_long_chain_dataset(
state_size=1,
num_obs=5,
steps_per_obs=3,
variance=1.,
observation_variance=1.,
batch_size=4,
num_samples=1,
observation_type=models.STANDARD_OBSERVATION,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
dtype="float32"):
"""Creates a long chain data generating process.
Creates a tf.data.Dataset that provides batches of data from a long
chain.
Args:
state_size: The dimension of the state space of the process.
num_obs: The number of observations in the chain.
steps_per_obs: The number of steps between each observation.
variance: The variance of the normal distributions used at each timestep.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
dtype: The datatype of the states and observations.
Returns:
dataset: A tf.data.Dataset that can be iterated over.
"""
num_timesteps = num_obs * steps_per_obs
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
observations = []
# z0 ~ Normal(0, sqrt(variance)).
states.append(
np.random.normal(size=[state_size],
scale=np.sqrt(variance)).astype(dtype))
# start at 1 because we've already generated z0
# go to num_timesteps+1 because we want to include the num_timesteps-th step
for t in xrange(1, num_timesteps+1):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance))
states.append(new_state.astype(dtype))
if t % steps_per_obs == 0:
if fixed_observation is None:
if observation_type == models.SQUARED_OBSERVATION:
loc = np.square(states[-1])
elif observation_type == models.ABS_OBSERVATION:
loc = np.abs(states[-1])
elif observation_type == models.STANDARD_OBSERVATION:
loc = states[-1]
new_obs = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(observation_variance)).astype(dtype)
else:
new_obs = np.ones([state_size])* fixed_observation
observations.append(new_obs)
yield states, observations
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return dataset
def make_dataset(bs=None,
state_size=1,
num_timesteps=10,
variance=1.,
prior_type="unimodal",
bimodal_prior_weight=0.5,
bimodal_prior_mean=1,
transition_type=models.STANDARD_TRANSITION,
fixed_observation=None,
batch_size=4,
num_samples=1,
dtype='float32'):
"""Creates a data generating process.
Creates a tf.data.Dataset that provides batches of data.
Args:
bs: The parameters of the data generating process. If None, new bs are
randomly generated.
state_size: The dimension of the state space of the process.
num_timesteps: The length of the state sequences in the process.
variance: The variance of the normal distributions used at each timestep.
batch_size: The number of trajectories to include in each batch.
num_samples: The number of replicas of each trajectory to include in each
batch.
Returns:
bs: The true bs used to generate the data
dataset: A tf.data.Dataset that can be iterated over.
"""
if bs is None:
bs = [np.random.uniform(size=[state_size]).astype(dtype) for _ in xrange(num_timesteps)]
tf.logging.info("data generating processs bs: %s",
np.array(bs).reshape(num_timesteps))
def data_generator():
"""An infinite generator of latents and observations from the model."""
while True:
states = []
if prior_type == "unimodal" or prior_type == "nonlinear":
# Prior is Normal(0, sqrt(variance)).
states.append(np.random.normal(size=[state_size], scale=np.sqrt(variance)).astype(dtype))
elif prior_type == "bimodal":
if np.random.uniform() > bimodal_prior_weight:
loc = bimodal_prior_mean
else:
loc = - bimodal_prior_mean
states.append(np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)
).astype(dtype))
for t in xrange(num_timesteps):
if transition_type == models.ROUND_TRANSITION:
loc = np.round(states[-1])
elif transition_type == models.STANDARD_TRANSITION:
loc = states[-1]
loc += bs[t]
new_state = np.random.normal(size=[state_size],
loc=loc,
scale=np.sqrt(variance)).astype(dtype)
states.append(new_state)
if fixed_observation is None:
observation = states[-1]
else:
observation = np.ones_like(states[-1]) * fixed_observation
yield np.array(states[:-1]), observation
dataset = tf.data.Dataset.from_generator(
data_generator,
output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
output_shapes=([num_timesteps, state_size], [state_size]))
dataset = dataset.repeat().batch(batch_size)
def tile_batch(state, observation):
state = tf.tile(state, [num_samples, 1, 1])
observation = tf.tile(observation, [num_samples, 1])
return state, observation
dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
return np.array(bs), dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment