Commit 5655c705 authored by Lukasz Kaiser, committed by GitHub

Merge pull request #300 from panyx0718/models-textsum

Add text summarization model to tensorflow/models.
parents c711dc70 56a05f68
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//textsum/...",
],
)
py_library(
name = "seq2seq_attention_model",
srcs = ["seq2seq_attention_model.py"],
deps = [
":seq2seq_lib",
],
)
py_library(
name = "seq2seq_lib",
srcs = ["seq2seq_lib.py"],
)
py_binary(
name = "seq2seq_attention",
srcs = ["seq2seq_attention.py"],
deps = [
":batch_reader",
":data",
":seq2seq_attention_decode",
":seq2seq_attention_model",
],
)
py_library(
name = "batch_reader",
srcs = ["batch_reader.py"],
deps = [
":data",
":seq2seq_attention_model",
],
)
py_library(
name = "beam_search",
srcs = ["beam_search.py"],
)
py_library(
name = "seq2seq_attention_decode",
srcs = ["seq2seq_attention_decode.py"],
deps = [
":beam_search",
":data",
],
)
py_library(
name = "data",
srcs = ["data.py"],
)
Sequence-to-Sequence with Attention Model for Text Summarization.
Authors:
Xin Pan (xpan@google.com, github:panyx0718), Peter Liu (peterjliu@google.com)
<b>Introduction</b>
The core model is the traditional sequence-to-sequence model with attention.
It is customized (mostly in its inputs/outputs) for the text summarization task.
The model has been trained on the Gigaword dataset and achieved state-of-the-art
results (as of June 2016).
The results described below are based on a model trained in a multi-GPU,
multi-machine setting. The code has been simplified to run on a single machine
for open-source purposes.
<b>DataSet</b>
We used the Gigaword dataset described in
https://arxiv.org/pdf/1602.06023.pdf
We cannot provide the dataset due to licensing restrictions. See ExampleGen in
data.py for the data format. data/data contains a toy example, and data/vocab
shows the vocabulary format. In <b>How To Run</b> below, you can use the toy
data and vocab provided in the data/ directory to run training by pointing the
--data_path and --vocab_path flags at them.
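For concreteness, the binary format that ExampleGen in data.py reads is a
sequence of records, each an 8-byte length followed by a serialized tf.Example
whose features hold the article and abstract text. The following is a minimal,
hypothetical sketch (not part of the repository) of writing such a file; the
helper name and sample strings are illustrative, and the feature keys must
match the --article_key/--abstract_key flags used below:

```python
# A minimal sketch (an assumption, not repository code) of writing training
# data in the length-prefixed tf.Example format that ExampleGen in data.py
# reads back. Sentences are wrapped in <s> ... </s> so that data.ToSentences
# can split them later.
import struct

from tensorflow.core.example import example_pb2


def WriteExample(writer, article, abstract):
  """Appends one length-prefixed serialized tf.Example to the open file."""
  example = example_pb2.Example()
  example.features.feature['article'].bytes_list.value.append(article)
  example.features.feature['abstract'].bytes_list.value.append(abstract)
  serialized = example.SerializeToString()
  writer.write(struct.pack('q', len(serialized)))  # 8-byte length header.
  writer.write(serialized)


with open('data/training-0', 'wb') as writer:
  WriteExample(writer,
               '<s> novell ceo eric schmidt named google chairman . </s>',
               '<s> novell ceo named google chairman . </s>')
```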
<b>Experiment Result</b>
We sampled 8,000 examples from the test set to generate summaries, and computed
ROUGE scores for the generated summaries. Here are the best ROUGE scores on the
Gigaword dataset:
ROUGE-1 Average_R: 0.38272 (95%-conf.int. 0.37774 - 0.38755)
ROUGE-1 Average_P: 0.50154 (95%-conf.int. 0.49509 - 0.50780)
ROUGE-1 Average_F: 0.42568 (95%-conf.int. 0.42016 - 0.43099)
ROUGE-2 Average_R: 0.20576 (95%-conf.int. 0.20060 - 0.21112)
ROUGE-2 Average_P: 0.27565 (95%-conf.int. 0.26851 - 0.28257)
ROUGE-2 Average_F: 0.23126 (95%-conf.int. 0.22539 - 0.23708)
<b>Configuration:</b>
The following is the configuration of the best trained model on Gigaword:
batch_size: 64
bidirectional encoding layers: 4
article length: first 2 sentences, at most 120 words in total.
summary length: at most 30 words in total.
word embedding size: 128
LSTM hidden units: 256
sampled softmax size: 4096
vocabulary size: the most frequent 200k words from the dataset's articles and summaries.
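These values map onto the HParams namedtuple consumed by seq2seq_attention.py.
As a sketch (an assumption about the original multi-GPU run; main() in this
repository builds a similar tuple but defaults to batch_size=4 for
single-machine use):

```python
# Hypothetical HParams for the configuration above; treat this as an
# assumption about the original run, not a repository default.
import seq2seq_attention_model

hps = seq2seq_attention_model.HParams(
    mode='train',
    min_lr=0.01,              # minimum learning rate
    lr=0.15,                  # initial learning rate
    batch_size=64,            # batch_size: 64
    enc_layers=4,             # bidirectional encoding layers: 4
    enc_timesteps=120,        # article: at most 120 words
    dec_timesteps=30,         # summary: at most 30 words
    min_input_len=2,
    num_hidden=256,           # LSTM hidden units: 256
    emb_dim=128,              # word embedding size: 128
    max_grad_norm=2,
    num_softmax_samples=4096) # sampled softmax: 4096
```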
<b>How To Run</b>
Prerequisite:
Install TensorFlow and Bazel.
```shell
# cd to your workspace
# clone the code to your workspace and create empty WORKSPACE file.
# move the data to your workspace. If you don't have the full dataset yet,
# copy the toy data from the data/ directory in the code directory and rename
# the files.
ls -R
.:
data textsum WORKSPACE
./data:
vocab test-0 training-0 training-1 validation-0 ...(omitted)
./textsum:
batch_reader.py beam_search.py BUILD README.md seq2seq_attention_model.py data
data.py seq2seq_attention_decode.py seq2seq_attention.py seq2seq_lib.py
./textsum/data:
data vocab
bazel build -c opt --config=cuda textsum/...
# Run the training.
bazel-bin/textsum/seq2seq_attention \
--mode=train \
--article_key=article \
--abstract_key=abstract \
--data_path=data/training-* \
--vocab_path=data/vocab \
--log_root=textsum/log_root \
--train_dir=textsum/log_root/train
# Run the eval. Try to avoid running it on the same machine as training.
bazel-bin/textsum/seq2seq_attention \
--mode=eval \
--article_key=article \
--abstract_key=abstract \
--data_path=data/validation-* \
--vocab_path=data/vocab \
--log_root=textsum/log_root \
--eval_dir=textsum/log_root/eval
# Run the decode. Run it when the model is mostly converged.
bazel-bin/textsum/seq2seq_attention \
--mode=decode \
--article_key=article \
--abstract_key=abstract \
--data_path=data/test-* \
--vocab_path=data/vocab \
--log_root=textsum/log_root \
--decode_dir=textsum/log_root/decode \
--beam_size=8
```
<b>Examples:</b>
The following are some text summarization examples, including experiments
using datasets other than Gigaword.
article: novell inc. chief executive officer eric schmidt has been named chairman of the internet search-engine company google .
human: novell ceo named google chairman
machine: novell chief executive named to head internet company
======================================
article: gulf newspapers voiced skepticism thursday over whether newly re - elected us president bill clinton could help revive the troubled middle east peace process but saw a glimmer of hope .
human: gulf skeptical about whether clinton will revive peace process
machine: gulf press skeptical over clinton 's prospects for peace process
======================================
article: the european court of justice ( ecj ) recently ruled in lock v british gas trading ltd that eu law requires a worker 's statutory holiday pay to take commission payments into account - it should not be based solely on basic salary . the case is not over yet , but its outcome could potentially be costly for employers with workers who are entitled to commission . mr lock , an energy salesman for british gas , was paid a basic salary and sales commission on a monthly basis . his sales commission made up around 60 % of his remuneration package . when he took two weeks ' annual leave in december 2012 , he was paid his basic salary and also received commission from previous sales that fell due during that period . lock obviously did not generate new sales while he was on holiday , which meant that in the following period he suffered a reduced income through lack of commission . he brought an employment tribunal claim asserting that this amounted to a breach of the working time regulations 1998 .....deleted rest for readability...
abstract: will british gas ecj ruling fuel holiday pay hike ?
decode: eu law requires worker 's statutory holiday pay
======================================
article: the junior all whites have been eliminated from the fifa u - 20 world cup in colombia with results on the final day of pool play confirming their exit . sitting on two points , new zealand needed results in one of the final two groups to go their way to join the last 16 as one of the four best third place teams . but while spain helped the kiwis ' cause with a 5 - 1 thrashing of australia , a 3 - 0 win for ecuador over costa rica saw the south americans climb to second in group c with costa rica 's three points also good enough to progress in third place . that left the junior all whites hopes hanging on the group d encounter between croatia and honduras finishing in a draw . a stalemate - and a place in the knockout stages for new zealand - appeared on the cards until midfielder marvin ceballos netted an 81st minute winner that sent guatemala through to the second round and left the junior all whites packing their bags . new zealand finishes the 24 - nation tournament in 17th place , having claimed their first ever points at this level in just their second appearance at the finals .
abstract: junior all whites exit world cup
decoded: junior all whites eliminated from u- 20 world cup
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batch reader to seq2seq attention model, with bucketing support."""
from collections import namedtuple
import Queue
from random import shuffle
from threading import Thread
import time
import numpy as np
import tensorflow as tf
import data
ModelInput = namedtuple('ModelInput',
'enc_input dec_input target enc_len dec_len '
'origin_article origin_abstract')
BUCKET_CACHE_BATCH = 100
QUEUE_NUM_BATCH = 100
class Batcher(object):
"""Batch reader with shuffling and bucketing support."""
def __init__(self, data_path, vocab, hps,
article_key, abstract_key, max_article_sentences,
max_abstract_sentences, bucketing=True, truncate_input=False):
"""Batcher constructor.
Args:
data_path: tf.Example filepattern.
vocab: Vocabulary.
hps: Seq2SeqAttention model hyperparameters.
article_key: article feature key in tf.Example.
abstract_key: abstract feature key in tf.Example.
max_article_sentences: Max number of sentences used from article.
max_abstract_sentences: Max number of sentences used from abstract.
bucketing: Whether to bucket articles of similar length into the same batch.
truncate_input: Whether to truncate input that is too long. The alternative
is to discard such examples.
"""
self._data_path = data_path
self._vocab = vocab
self._hps = hps
self._article_key = article_key
self._abstract_key = abstract_key
self._max_article_sentences = max_article_sentences
self._max_abstract_sentences = max_abstract_sentences
self._bucketing = bucketing
self._truncate_input = truncate_input
self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
self._input_threads = []
for _ in xrange(16):
self._input_threads.append(Thread(target=self._FillInputQueue))
self._input_threads[-1].daemon = True
self._input_threads[-1].start()
self._bucketing_threads = []
for _ in xrange(4):
self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
self._bucketing_threads[-1].daemon = True
self._bucketing_threads[-1].start()
self._watch_thread = Thread(target=self._WatchThreads)
self._watch_thread.daemon = True
self._watch_thread.start()
def NextBatch(self):
"""Returns a batch of inputs for seq2seq attention model.
Returns:
enc_batch: A batch of encoder inputs [batch_size, hps.enc_timesteps].
dec_batch: A batch of decoder inputs [batch_size, hps.dec_timesteps].
target_batch: A batch of targets [batch_size, hps.dec_timesteps].
enc_input_len: encoder input lengths of the batch.
dec_input_len: decoder input lengths of the batch.
loss_weights: weights for the loss function; 1 if not padded, 0 if padded.
origin_articles: original article words.
origin_abstracts: original abstract words.
"""
enc_batch = np.zeros(
(self._hps.batch_size, self._hps.enc_timesteps), dtype=np.int32)
enc_input_lens = np.zeros(
(self._hps.batch_size), dtype=np.int32)
dec_batch = np.zeros(
(self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
dec_output_lens = np.zeros(
(self._hps.batch_size), dtype=np.int32)
target_batch = np.zeros(
(self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
loss_weights = np.zeros(
(self._hps.batch_size, self._hps.dec_timesteps), dtype=np.float32)
origin_articles = ['None'] * self._hps.batch_size
origin_abstracts = ['None'] * self._hps.batch_size
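# Each element of the bucket queue is a list of batch_size ModelInputs that
# _FillBucketInputQueue grouped together (bucketed by encoder length).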
buckets = self._bucket_input_queue.get()
for i in xrange(self._hps.batch_size):
(enc_inputs, dec_inputs, targets, enc_input_len, dec_output_len,
article, abstract) = buckets[i]
origin_articles[i] = article
origin_abstracts[i] = abstract
enc_input_lens[i] = enc_input_len
dec_output_lens[i] = dec_output_len
enc_batch[i, :] = enc_inputs[:]
dec_batch[i, :] = dec_inputs[:]
target_batch[i, :] = targets[:]
for j in xrange(dec_output_len):
loss_weights[i][j] = 1
return (enc_batch, dec_batch, target_batch, enc_input_lens, dec_output_lens,
loss_weights, origin_articles, origin_abstracts)
def _FillInputQueue(self):
"""Fill input queue with ModelInput."""
start_id = self._vocab.WordToId(data.SENTENCE_START)
end_id = self._vocab.WordToId(data.SENTENCE_END)
pad_id = self._vocab.WordToId(data.PAD_TOKEN)
input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
while True:
(article, abstract) = input_gen.next()
article_sentences = [sent.strip() for sent in
data.ToSentences(article, include_token=False)]
abstract_sentences = [sent.strip() for sent in
data.ToSentences(abstract, include_token=False)]
enc_inputs = []
# Use the <s> as the <GO> symbol for decoder inputs.
dec_inputs = [start_id]
# Convert first N sentences to word IDs, stripping existing <s> and </s>.
for i in xrange(min(self._max_article_sentences,
len(article_sentences))):
enc_inputs += data.GetWordIds(article_sentences[i], self._vocab)
for i in xrange(min(self._max_abstract_sentences,
len(abstract_sentences))):
dec_inputs += data.GetWordIds(abstract_sentences[i], self._vocab)
# Filter out too-short input
if (len(enc_inputs) < self._hps.min_input_len or
len(dec_inputs) < self._hps.min_input_len):
tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d',
len(enc_inputs), len(dec_inputs))
continue
# If we're not truncating input, throw out too-long input
if not self._truncate_input:
if (len(enc_inputs) > self._hps.enc_timesteps or
len(dec_inputs) > self._hps.dec_timesteps):
tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d',
len(enc_inputs), len(dec_inputs))
continue
# If we are truncating input, do so if necessary
else:
if len(enc_inputs) > self._hps.enc_timesteps:
enc_inputs = enc_inputs[:self._hps.enc_timesteps]
if len(dec_inputs) > self._hps.dec_timesteps:
dec_inputs = dec_inputs[:self._hps.dec_timesteps]
# targets is dec_inputs without <s> at beginning, plus </s> at end
targets = dec_inputs[1:]
targets.append(end_id)
# Now len(enc_inputs) should be <= enc_timesteps, and
# len(targets) = len(dec_inputs) should be <= dec_timesteps
enc_input_len = len(enc_inputs)
dec_output_len = len(targets)
# Pad if necessary
while len(enc_inputs) < self._hps.enc_timesteps:
enc_inputs.append(pad_id)
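# Note: decoder inputs and targets are padded with end_id rather than pad_id;
# the padded positions receive zero loss_weights in NextBatch, so they carry
# no loss.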
while len(dec_inputs) < self._hps.dec_timesteps:
dec_inputs.append(end_id)
while len(targets) < self._hps.dec_timesteps:
targets.append(end_id)
element = ModelInput(enc_inputs, dec_inputs, targets, enc_input_len,
dec_output_len, ' '.join(article_sentences),
' '.join(abstract_sentences))
self._input_queue.put(element)
def _FillBucketInputQueue(self):
"""Fill bucketed batches into the bucket_input_queue."""
while True:
inputs = []
for _ in xrange(self._hps.batch_size * BUCKET_CACHE_BATCH):
inputs.append(self._input_queue.get())
if self._bucketing:
inputs = sorted(inputs, key=lambda inp: inp.enc_len)
batches = []
for i in xrange(0, len(inputs), self._hps.batch_size):
batches.append(inputs[i:i+self._hps.batch_size])
shuffle(batches)
for b in batches:
self._bucket_input_queue.put(b)
def _WatchThreads(self):
"""Watch the daemon input threads and restart if dead."""
while True:
time.sleep(60)
input_threads = []
for t in self._input_threads:
if t.is_alive():
input_threads.append(t)
else:
tf.logging.error('Found input thread dead.')
new_t = Thread(target=self._FillInputQueue)
input_threads.append(new_t)
input_threads[-1].daemon = True
input_threads[-1].start()
self._input_threads = input_threads
bucketing_threads = []
for t in self._bucketing_threads:
if t.is_alive():
bucketing_threads.append(t)
else:
tf.logging.error('Found bucketing thread dead.')
new_t = Thread(target=self._FillBucketInputQueue)
bucketing_threads.append(new_t)
bucketing_threads[-1].daemon = True
bucketing_threads[-1].start()
self._bucketing_threads = bucketing_threads
def _TextGenerator(self, example_gen):
"""Generates article and abstract text from tf.Example."""
while True:
e = example_gen.next()
try:
article_text = self._GetExFeatureText(e, self._article_key)
abstract_text = self._GetExFeatureText(e, self._abstract_key)
except ValueError:
tf.logging.error('Failed to get article or abstract from example')
continue
yield (article_text, abstract_text)
def _GetExFeatureText(self, ex, key):
"""Extract text for a feature from td.Example.
Args:
ex: tf.Example.
key: key of the feature to be extracted.
Returns:
feature: the extracted feature text.
"""
return ex.features.feature[key].bytes_list.value[0]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search module.
Beam search takes the top K results from the model, predicts the K results for
each of the previous K results, getting K*K candidates. It picks the top K
results from the K*K candidates, and repeats until a certain number of results
are fully decoded.
"""
import tensorflow as tf
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_bool('normalize_by_length', True, 'Whether to normalize by length.')
class Hypothesis(object):
"""Defines a hypothesis during beam search."""
def __init__(self, tokens, log_prob, state):
"""Hypothesis constructor.
Args:
tokens: start tokens for decoding.
log_prob: log prob of the start tokens, usually 0 (probability 1).
state: decoder initial states.
"""
self.tokens = tokens
self.log_prob = log_prob
self.state = state
def Extend(self, token, log_prob, new_state):
"""Extend the hypothesis with result from latest step.
Args:
token: latest token from decoding.
log_prob: log prob of the latest decoded tokens.
new_state: decoder output state. Fed to the decoder for next step.
Returns:
New Hypothesis with the results from latest step.
"""
return Hypothesis(self.tokens + [token], self.log_prob + log_prob,
new_state)
@property
def latest_token(self):
return self.tokens[-1]
def __str__(self):
return ('Hypothesis(log prob = %.4f, tokens = %s)' % (self.log_prob,
self.tokens))
class BeamSearch(object):
"""Beam search."""
def __init__(self, model, beam_size, start_token, end_token, max_steps):
"""Creates BeamSearch object.
Args:
model: Seq2SeqAttentionModel.
beam_size: int.
start_token: int, id of the token to start decoding with.
end_token: int, id of the token that completes a hypothesis.
max_steps: int, upper limit on the length of a hypothesis.
"""
self._model = model
self._beam_size = beam_size
self._start_token = start_token
self._end_token = end_token
self._max_steps = max_steps
def BeamSearch(self, sess, enc_inputs, enc_seqlen):
"""Performs beam search for decoding.
Args:
sess: tf.Session, session
enc_inputs: ndarray of shape (enc_length, 1), the document ids to encode
enc_seqlen: ndarray of shape (1), the length of the sequence
Returns:
hyps: list of Hypothesis, the best hypotheses found by beam search,
ordered by score
"""
# Run the encoder and extract the outputs and final state.
enc_top_states, dec_in_state = self._model.encode_top_state(
sess, enc_inputs, enc_seqlen)
# Replicate the initial states K times for the first step.
hyps = [Hypothesis([self._start_token], 0.0, dec_in_state)
] * self._beam_size
results = []
steps = 0
while steps < self._max_steps and len(results) < self._beam_size:
latest_tokens = [h.latest_token for h in hyps]
states = [h.state for h in hyps]
topk_ids, topk_log_probs, new_states = self._model.decode_topk(
sess, latest_tokens, enc_top_states, states)
# Extend each hypothesis.
all_hyps = []
# The first step takes the best K results from the first hypothesis.
# Following steps take the best K results from the K*K candidates.
num_beam_source = 1 if steps == 0 else len(hyps)
for i in xrange(num_beam_source):
h, ns = hyps[i], new_states[i]
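# decode_topk returns 2*beam_size candidates per hypothesis (the model's
# top_k uses batch_size*2, and batch_size == beam_size in decode mode), so
# enough non-</s> extensions survive the filtering below.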
for j in xrange(self._beam_size*2):
all_hyps.append(h.Extend(topk_ids[i, j], topk_log_probs[i, j], ns))
# Filter and collect any hypotheses that have the end token.
hyps = []
for h in self._BestHyps(all_hyps):
if h.latest_token == self._end_token:
# Pull the hypothesis off the beam if the end token is reached.
results.append(h)
else:
# Otherwise continue to extend the hypothesis.
hyps.append(h)
if len(hyps) == self._beam_size or len(results) == self._beam_size:
break
steps += 1
if steps == self._max_steps:
results.extend(hyps)
return self._BestHyps(results)
def _BestHyps(self, hyps):
"""Sort the hyps based on log probs and length.
Args:
hyps: A list of hypothesis.
Returns:
hyps: A list of sorted hypothesis in reverse log_prob order.
"""
# This length normalization is only effective for the final results.
if FLAGS.normalize_by_length:
return sorted(hyps, key=lambda h: h.log_prob/len(h.tokens), reverse=True)
else:
return sorted(hyps, key=lambda h: h.log_prob, reverse=True)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data batchers for data described in ..//data_prep/README.md."""
import glob
import random
import struct
import sys
from tensorflow.core.example import example_pb2
# Special tokens
PARAGRAPH_START = '<p>'
PARAGRAPH_END = '</p>'
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
UNKNOWN_TOKEN = '<UNK>'
PAD_TOKEN = '<PAD>'
DOCUMENT_START = '<d>'
DOCUMENT_END = '</d>'
class Vocab(object):
"""Vocabulary class for mapping words and ids."""
def __init__(self, vocab_file, max_size):
self._word_to_id = {}
self._id_to_word = {}
self._count = 0
with open(vocab_file, 'r') as vocab_f:
for line in vocab_f:
pieces = line.split()
if len(pieces) != 2:
sys.stderr.write('Bad line: %s\n' % line)
continue
if pieces[0] in self._word_to_id:
raise ValueError('Duplicated word: %s.' % pieces[0])
self._word_to_id[pieces[0]] = self._count
self._id_to_word[self._count] = pieces[0]
self._count += 1
if self._count > max_size:
raise ValueError('Too many words: >%d.' % max_size)
def WordToId(self, word):
if word not in self._word_to_id:
return self._word_to_id[UNKNOWN_TOKEN]
return self._word_to_id[word]
def IdToWord(self, word_id):
if word_id not in self._id_to_word:
raise ValueError('id not found in vocab: %d.' % word_id)
return self._id_to_word[word_id]
def NumIds(self):
return self._count
def ExampleGen(recordio_path, num_epochs=None):
"""Generates tf.Examples from path of recordio files.
Args:
recordio_path: path expression (glob) to tf.Example recordio files.
num_epochs: Number of times to go through the data. None means infinite.
Yields:
Deserialized tf.Example.
If multiple files are specified, they are accessed in a random order.
"""
epoch = 0
while True:
if num_epochs is not None and epoch >= num_epochs:
break
filelist = glob.glob(recordio_path)
assert filelist, 'Empty filelist.'
random.shuffle(filelist)
for f in filelist:
reader = open(f, 'rb')
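# Each record is an 8-byte int64 length ('q') followed by that many bytes
# of a serialized tf.Example.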
while True:
len_bytes = reader.read(8)
if not len_bytes: break
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
yield example_pb2.Example.FromString(example_str)
epoch += 1
def Pad(ids, pad_id, length):
"""Pad or trim list to len length.
Args:
ids: list of ints to pad
pad_id: what to pad with
length: length to pad or trim to
Returns:
ids trimmed or padded with pad_id
"""
assert pad_id is not None
assert length is not None
if len(ids) < length:
a = [pad_id] * (length - len(ids))
return ids + a
else:
return ids[:length]
def GetWordIds(text, vocab, pad_len=None, pad_id=None):
"""Get ids corresponding to words in text.
Assumes tokens separated by space.
Args:
text: a string
vocab: Vocab object
pad_len: int, length to pad to
pad_id: int, word id for pad symbol
Returns:
A list of ints representing word ids.
"""
ids = []
for w in text.split():
i = vocab.WordToId(w)
if i >= 0:
ids.append(i)
else:
ids.append(vocab.WordToId(UNKNOWN_TOKEN))
if pad_len is not None:
return Pad(ids, pad_id, pad_len)
return ids
def Ids2Words(ids_list, vocab):
"""Get words from ids.
Args:
ids_list: list of int32
vocab: Vocab object
Returns:
List of words corresponding to ids.
"""
assert isinstance(ids_list, list), '%s is not a list' % ids_list
return [vocab.IdToWord(i) for i in ids_list]
def SnippetGen(text, start_tok, end_tok, inclusive=True):
"""Generates consecutive snippets between start and end tokens.
Args:
text: a string
start_tok: a string denoting the start of snippets
end_tok: a string denoting the end of snippets
inclusive: Whether to include the tokens in the returned snippets.
Yields:
String snippets
"""
cur = 0
while True:
try:
start_p = text.index(start_tok, cur)
end_p = text.index(end_tok, start_p + 1)
cur = end_p + len(end_tok)
if inclusive:
yield text[start_p:cur]
else:
yield text[start_p+len(start_tok):end_p]
except ValueError as e:
raise StopIteration('no more snippets in text: %s' % e)
def GetExFeatureText(ex, key):
return ex.features.feature[key].bytes_list.value[0]
def ToSentences(paragraph, include_token=True):
"""Takes tokens of a paragraph and returns list of sentences.
Args:
paragraph: string, text of paragraph
include_token: Whether to include the sentence separator tokens in the result.
Returns:
List of sentence strings.
"""
s_gen = SnippetGen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
return [s for s in s_gen]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains a seq2seq model.
WORK IN PROGRESS.
Implement "Abstractive Text Summarization using Sequence-to-sequence RNNS and
Beyond."
"""
import sys
import time
import tensorflow as tf
import batch_reader
import data
import seq2seq_attention_decode
import seq2seq_attention_model
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_path',
'', 'Path expression to tf.Example.')
tf.app.flags.DEFINE_string('vocab_path',
'', 'Path expression to text vocabulary file.')
tf.app.flags.DEFINE_string('article_key', 'article',
'tf.Example feature key for article.')
tf.app.flags.DEFINE_string('abstract_key', 'headline',
'tf.Example feature key for abstract.')
tf.app.flags.DEFINE_string('log_root', '', 'Directory for model root.')
tf.app.flags.DEFINE_string('train_dir', '', 'Directory for train.')
tf.app.flags.DEFINE_string('eval_dir', '', 'Directory for eval.')
tf.app.flags.DEFINE_string('decode_dir', '', 'Directory for decode summaries.')
tf.app.flags.DEFINE_string('mode', 'train', 'train/eval/decode mode')
tf.app.flags.DEFINE_integer('max_run_steps', 10000000,
'Maximum number of run steps.')
tf.app.flags.DEFINE_integer('max_article_sentences', 2,
'Max number of first sentences to use from the '
'article')
tf.app.flags.DEFINE_integer('max_abstract_sentences', 100,
'Max number of first sentences to use from the '
'abstract')
tf.app.flags.DEFINE_integer('beam_size', 4,
'beam size for beam search decoding.')
tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run eval.')
tf.app.flags.DEFINE_integer('checkpoint_secs', 60, 'How often to checkpoint.')
tf.app.flags.DEFINE_bool('use_bucketing', False,
'Whether to bucket articles of similar length.')
tf.app.flags.DEFINE_bool('truncate_input', False,
'Truncate inputs that are too long. If False, '
'examples that are too long are discarded.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of gpus used.')
tf.app.flags.DEFINE_integer('random_seed', 111, 'A seed value for randomness.')
def _RunningAvgLoss(loss, running_avg_loss, summary_writer, step, decay=0.999):
"""Calculate the running average of losses."""
if running_avg_loss == 0:
running_avg_loss = loss
else:
running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
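# Clip the running average so early outlier losses do not dominate the
# summary plot; this mirrors the tf.minimum(12.0, loss) clip on the model's
# loss summary.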
running_avg_loss = min(running_avg_loss, 12)
loss_sum = tf.Summary()
loss_sum.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
summary_writer.add_summary(loss_sum, step)
sys.stdout.write('running_avg_loss: %f\n' % running_avg_loss)
return running_avg_loss
def _Train(model, data_batcher):
"""Runs model training."""
with tf.device('/cpu:0'):
model.build_graph()
saver = tf.train.Saver()
# Train dir is different from log_root to avoid summary directory
# conflict with Supervisor.
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
sv = tf.train.Supervisor(logdir=FLAGS.log_root,
is_chief=True,
saver=saver,
summary_op=None,
save_summaries_secs=60,
save_model_secs=FLAGS.checkpoint_secs,
global_step=model.global_step)
sess = sv.prepare_or_wait_for_session()
running_avg_loss = 0
step = 0
while not sv.should_stop() and step < FLAGS.max_run_steps:
(article_batch, abstract_batch, targets, article_lens, abstract_lens,
loss_weights, _, _) = data_batcher.NextBatch()
(_, summaries, loss, train_step) = model.run_train_step(
sess, article_batch, abstract_batch, targets, article_lens,
abstract_lens, loss_weights)
summary_writer.add_summary(summaries, train_step)
running_avg_loss = _RunningAvgLoss(
loss, running_avg_loss, summary_writer, train_step)
step += 1
if step % 100 == 0:
summary_writer.flush()
sv.Stop()
return running_avg_loss
def _Eval(model, data_batcher, vocab=None):
"""Runs model eval."""
model.build_graph()
saver = tf.train.Saver()
summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
running_avg_loss = 0
step = 0
while True:
time.sleep(FLAGS.eval_interval_secs)
try:
ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
except tf.errors.OutOfRangeError as e:
tf.logging.error('Cannot restore checkpoint: %s', e)
continue
if not (ckpt_state and ckpt_state.model_checkpoint_path):
tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
continue
tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
saver.restore(sess, ckpt_state.model_checkpoint_path)
(article_batch, abstract_batch, targets, article_lens, abstract_lens,
loss_weights, _, _) = data_batcher.NextBatch()
(summaries, loss, train_step) = model.run_eval_step(
sess, article_batch, abstract_batch, targets, article_lens,
abstract_lens, loss_weights)
tf.logging.info(
'article: %s',
' '.join(data.Ids2Words(article_batch[0][:].tolist(), vocab)))
tf.logging.info(
'abstract: %s',
' '.join(data.Ids2Words(abstract_batch[0][:].tolist(), vocab)))
summary_writer.add_summary(summaries, train_step)
running_avg_loss = _RunningAvgLoss(
loss, running_avg_loss, summary_writer, train_step)
step += 1
if step % 100 == 0:
summary_writer.flush()
def main(unused_argv):
vocab = data.Vocab(FLAGS.vocab_path, 1000000)
# Check for presence of required special tokens.
assert vocab.WordToId(data.PAD_TOKEN) > 0
assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
assert vocab.WordToId(data.SENTENCE_START) > 0
assert vocab.WordToId(data.SENTENCE_END) > 0
batch_size = 4
if FLAGS.mode == 'decode':
batch_size = FLAGS.beam_size
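# In decode mode, beam hypotheses are packed along the batch dimension, so
# the batch size must equal the beam size.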
hps = seq2seq_attention_model.HParams(
mode=FLAGS.mode, # train, eval, decode
min_lr=0.01, # min learning rate.
lr=0.15, # learning rate
batch_size=batch_size,
enc_layers=4,
enc_timesteps=120,
dec_timesteps=30,
min_input_len=2, # discard articles/summaries shorter than this
num_hidden=256, # for rnn cell
emb_dim=128, # If 0, don't use embedding
max_grad_norm=2,
num_softmax_samples=4096) # If 0, no sampled softmax.
batcher = batch_reader.Batcher(
FLAGS.data_path, vocab, hps, FLAGS.article_key,
FLAGS.abstract_key, FLAGS.max_article_sentences,
FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
truncate_input=FLAGS.truncate_input)
tf.set_random_seed(FLAGS.random_seed)
if hps.mode == 'train':
model = seq2seq_attention_model.Seq2SeqAttentionModel(
hps, vocab, num_gpus=FLAGS.num_gpus)
_Train(model, batcher)
elif hps.mode == 'eval':
model = seq2seq_attention_model.Seq2SeqAttentionModel(
hps, vocab, num_gpus=FLAGS.num_gpus)
_Eval(model, batcher, vocab=vocab)
elif hps.mode == 'decode':
decode_mdl_hps = hps
# Only need to build the decoder for the 1st step and reuse it, since
# we keep and feed in the state for each step's output.
decode_mdl_hps = hps._replace(dec_timesteps=1)
model = seq2seq_attention_model.Seq2SeqAttentionModel(
decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
decoder.DecodeLoop()
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for decoding."""
import os
import time
import tensorflow as tf
import beam_search
import data
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_decode_steps', 1000000,
'Number of decoding steps.')
tf.app.flags.DEFINE_integer('decode_batches_per_ckpt', 8000,
'Number of batches to decode before restoring next '
'checkpoint')
DECODE_LOOP_DELAY_SECS = 60
DECODE_IO_FLUSH_INTERVAL = 100
class DecodeIO(object):
"""Writes the decoded and references to RKV files for Rouge score.
See nlp/common/utils/internal/rkv_parser.py for detail about rkv file.
"""
def __init__(self, outdir):
self._cnt = 0
self._outdir = outdir
if not os.path.exists(self._outdir):
os.mkdir(self._outdir)
self._ref_file = None
self._decode_file = None
def Write(self, reference, decode):
"""Writes the reference and decoded outputs to RKV files.
Args:
reference: The human (correct) result.
decode: The machine-generated result
"""
self._ref_file.write('output=%s\n' % reference)
self._decode_file.write('output=%s\n' % decode)
self._cnt += 1
if self._cnt % DECODE_IO_FLUSH_INTERVAL == 0:
self._ref_file.flush()
self._decode_file.flush()
def ResetFiles(self):
"""Resets the output files. Must be called once before Write()."""
if self._ref_file: self._ref_file.close()
if self._decode_file: self._decode_file.close()
timestamp = int(time.time())
self._ref_file = open(
os.path.join(self._outdir, 'ref%d'%timestamp), 'w')
self._decode_file = open(
os.path.join(self._outdir, 'decode%d'%timestamp), 'w')
class BSDecoder(object):
"""Beam search decoder."""
def __init__(self, model, batch_reader, hps, vocab):
"""Beam search decoding.
Args:
model: The seq2seq attentional model.
batch_reader: The batch data reader.
hps: Hyperparameters.
vocab: Vocabulary.
"""
self._model = model
self._model.build_graph()
self._batch_reader = batch_reader
self._hps = hps
self._vocab = vocab
self._saver = tf.train.Saver()
self._decode_io = DecodeIO(FLAGS.decode_dir)
def DecodeLoop(self):
"""Decoding loop for long running process."""
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
step = 0
while step < FLAGS.max_decode_steps:
time.sleep(DECODE_LOOP_DELAY_SECS)
if not self._Decode(self._saver, sess):
continue
step += 1
def _Decode(self, saver, sess):
"""Restore a checkpoint and decode it.
Args:
saver: Tensorflow checkpoint saver.
sess: Tensorflow session.
Returns:
Returns True on success, False otherwise.
"""
ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
if not (ckpt_state and ckpt_state.model_checkpoint_path):
tf.logging.info('No model to decode yet at %s', FLAGS.log_root)
return False
tf.logging.info('checkpoint path %s', ckpt_state.model_checkpoint_path)
ckpt_path = os.path.join(
FLAGS.log_root, os.path.basename(ckpt_state.model_checkpoint_path))
tf.logging.info('renamed checkpoint path %s', ckpt_path)
saver.restore(sess, ckpt_path)
self._decode_io.ResetFiles()
for _ in xrange(FLAGS.decode_batches_per_ckpt):
(article_batch, _, _, article_lens, _, _, origin_articles,
origin_abstracts) = self._batch_reader.NextBatch()
for i in xrange(self._hps.batch_size):
bs = beam_search.BeamSearch(
self._model, self._hps.batch_size,
self._vocab.WordToId(data.SENTENCE_START),
self._vocab.WordToId(data.SENTENCE_END),
self._hps.dec_timesteps)
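# Replicate the i-th example across the whole batch so that every beam
# hypothesis decodes against the same encoder input.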
article_batch_cp = article_batch.copy()
article_batch_cp[:] = article_batch[i:i+1]
article_lens_cp = article_lens.copy()
article_lens_cp[:] = article_lens[i:i+1]
best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
decode_output = [int(t) for t in best_beam.tokens[1:]]
self._DecodeBatch(
origin_articles[i], origin_abstracts[i], decode_output)
return True
def _DecodeBatch(self, article, abstract, output_ids):
"""Convert id to words and writing results.
Args:
article: The original article string.
abstract: The human (correct) abstract string.
output_ids: The abstract word ids output by the machine.
"""
decoded_output = ' '.join(data.Ids2Words(output_ids, self._vocab))
end_p = decoded_output.find(data.SENTENCE_END, 0)
if end_p != -1:
decoded_output = decoded_output[:end_p]
tf.logging.info('article: %s', article)
tf.logging.info('abstract: %s', abstract)
tf.logging.info('decoded: %s', decoded_output)
self._decode_io.Write(abstract, decoded_output.strip())
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence-to-Sequence with attention model for text summarization.
"""
from collections import namedtuple
import numpy as np
import tensorflow as tf
import seq2seq_lib
HParams = namedtuple('HParams',
'mode, min_lr, lr, batch_size, '
'enc_layers, enc_timesteps, dec_timesteps, '
'min_input_len, num_hidden, emb_dim, max_grad_norm, '
'num_softmax_samples')
def _extract_argmax_and_embed(embedding, output_projection=None,
update_embedding=True):
"""Get a loop_function that extracts the previous symbol and embeds it.
Args:
embedding: embedding tensor for symbols.
output_projection: None or a pair (W, B). If provided, each fed previous
output will first be multiplied by W and have B added to it.
update_embedding: Boolean; if False, the gradients will not propagate
through the embeddings.
Returns:
A loop function.
"""
def loop_function(prev, _):
"""function that feed previous model output rather than ground truth."""
if output_projection is not None:
prev = tf.nn.xw_plus_b(
prev, output_projection[0], output_projection[1])
prev_symbol = tf.argmax(prev, 1)
# Note that gradients will not propagate through the second parameter of
# embedding_lookup.
emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
if not update_embedding:
emb_prev = tf.stop_gradient(emb_prev)
return emb_prev
return loop_function
class Seq2SeqAttentionModel(object):
"""Wrapper for Tensorflow model graph for text sum vectors."""
def __init__(self, hps, vocab, num_gpus=0):
self._hps = hps
self._vocab = vocab
self._num_gpus = num_gpus
self._cur_gpu = 0
def run_train_step(self, sess, article_batch, abstract_batch, targets,
article_lens, abstract_lens, loss_weights):
to_return = [self._train_op, self._summaries, self._loss, self.global_step]
return sess.run(to_return,
feed_dict={self._articles: article_batch,
self._abstracts: abstract_batch,
self._targets: targets,
self._article_lens: article_lens,
self._abstract_lens: abstract_lens,
self._loss_weights: loss_weights})
def run_eval_step(self, sess, article_batch, abstract_batch, targets,
article_lens, abstract_lens, loss_weights):
to_return = [self._summaries, self._loss, self.global_step]
return sess.run(to_return,
feed_dict={self._articles: article_batch,
self._abstracts: abstract_batch,
self._targets: targets,
self._article_lens: article_lens,
self._abstract_lens: abstract_lens,
self._loss_weights: loss_weights})
def run_decode_step(self, sess, article_batch, abstract_batch, targets,
article_lens, abstract_lens, loss_weights):
to_return = [self._outputs, self.global_step]
return sess.run(to_return,
feed_dict={self._articles: article_batch,
self._abstracts: abstract_batch,
self._targets: targets,
self._article_lens: article_lens,
self._abstract_lens: abstract_lens,
self._loss_weights: loss_weights})
def _next_device(self):
"""Round robin the gpu device. (Reserve last gpu for expensive op)."""
if self._num_gpus == 0:
return ''
dev = '/gpu:%d' % self._cur_gpu
self._cur_gpu = (self._cur_gpu + 1) % (self._num_gpus-1)
return dev
def _get_gpu(self, gpu_id):
if self._num_gpus <= 0 or gpu_id >= self._num_gpus:
return ''
return '/gpu:%d' % gpu_id
def _add_placeholders(self):
"""Inputs to be fed to the graph."""
hps = self._hps
self._articles = tf.placeholder(tf.int32,
[hps.batch_size, hps.enc_timesteps],
name='articles')
self._abstracts = tf.placeholder(tf.int32,
[hps.batch_size, hps.dec_timesteps],
name='abstracts')
self._targets = tf.placeholder(tf.int32,
[hps.batch_size, hps.dec_timesteps],
name='targets')
self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
name='article_lens')
self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
name='abstract_lens')
self._loss_weights = tf.placeholder(tf.float32,
[hps.batch_size, hps.dec_timesteps],
name='loss_weights')
def _add_seq2seq(self):
hps = self._hps
vsize = self._vocab.NumIds()
with tf.variable_scope('seq2seq'):
encoder_inputs = tf.unpack(tf.transpose(self._articles))
decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
targets = tf.unpack(tf.transpose(self._targets))
loss_weights = tf.unpack(tf.transpose(self._loss_weights))
article_lens = self._article_lens
# Embedding shared by the input and outputs.
with tf.variable_scope('embedding'), tf.device('/cpu:0'):
embedding = tf.get_variable(
'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=1e-4))
emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in encoder_inputs]
emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in decoder_inputs]
for layer_i in xrange(hps.enc_layers):
with tf.variable_scope('encoder%d'%layer_i), tf.device(
self._next_device()):
cell_fw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
cell_bw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
(emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
sequence_length=article_lens)
encoder_outputs = emb_encoder_inputs
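# After stacking, fw_state holds the final forward-LSTM state of the top
# encoder layer; it becomes the decoder's initial state (_dec_in_state) below.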
with tf.variable_scope('output_projection'):
w = tf.get_variable(
'w', [hps.num_hidden, vsize], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=1e-4))
w_t = tf.transpose(w)
v = tf.get_variable(
'v', [vsize], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=1e-4))
with tf.variable_scope('decoder'), tf.device(self._next_device()):
# When decoding, use model output from the previous step
# for the next step.
loop_function = None
if hps.mode == 'decode':
loop_function = _extract_argmax_and_embed(
embedding, (w, v), update_embedding=False)
cell = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
for x in encoder_outputs]
self._enc_top_states = tf.concat(1, encoder_outputs)
self._dec_in_state = fw_state
# During decoding, subsequent _dec_in_state values are fed from beam_search.
# dec_out_state is stored by beam_search for feeding the next step.
initial_state_attention = (hps.mode == 'decode')
decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
cell, num_heads=1, loop_function=loop_function,
initial_state_attention=initial_state_attention)
with tf.variable_scope('output'), tf.device(self._next_device()):
model_outputs = []
for i in xrange(len(decoder_outputs)):
if i > 0:
tf.get_variable_scope().reuse_variables()
model_outputs.append(
tf.nn.xw_plus_b(decoder_outputs[i], w, v))
if hps.mode == 'decode':
with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
best_outputs = [tf.argmax(x, 1) for x in model_outputs]
tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
self._outputs = tf.concat(
1, [tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
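# Take the top 2*batch_size candidates; batch_size equals beam_size in
# decode mode, giving beam_search 2*beam_size extensions per hypothesis.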
self._topk_log_probs, self._topk_ids = tf.nn.top_k(
tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size*2)
with tf.variable_scope('loss'), tf.device(self._next_device()):
def sampled_loss_func(inputs, labels):
with tf.device('/cpu:0'): # Try gpu.
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels,
hps.num_softmax_samples, vsize)
if hps.num_softmax_samples != 0 and hps.mode == 'train':
self._loss = seq2seq_lib.sampled_sequence_loss(
decoder_outputs, targets, loss_weights, sampled_loss_func)
else:
self._loss = tf.nn.seq2seq.sequence_loss(
model_outputs, targets, loss_weights)
tf.scalar_summary('loss', tf.minimum(12.0, self._loss))
def _add_train_op(self):
"""Sets self._train_op, op to run for training."""
hps = self._hps
self._lr_rate = tf.maximum(
hps.min_lr, # min_lr_rate.
tf.train.exponential_decay(hps.lr, self.global_step, 30000, 0.98))
tvars = tf.trainable_variables()
with tf.device(self._get_gpu(self._num_gpus-1)):
grads, global_norm = tf.clip_by_global_norm(
tf.gradients(self._loss, tvars), hps.max_grad_norm)
tf.scalar_summary('global_norm', global_norm)
optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
tf.scalar_summary('learning rate', self._lr_rate)
self._train_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=self.global_step, name='train_step')
def encode_top_state(self, sess, enc_inputs, enc_len):
"""Return the top states from encoder for decoder.
Args:
sess: tensorflow session.
enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
enc_len: encoder input length of shape [batch_size]
Returns:
enc_top_states: The top level encoder states.
dec_in_state: The decoder layer initial state.
"""
results = sess.run([self._enc_top_states, self._dec_in_state],
feed_dict={self._articles: enc_inputs,
self._article_lens: enc_len})
return results[0], results[1][0]
def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
"""Return the topK results and new decoder states."""
feed = {
self._enc_top_states: enc_top_states,
self._dec_in_state:
np.squeeze(np.array(dec_init_states)),
self._abstracts:
np.transpose(np.array([latest_tokens])),
self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}
results = sess.run(
[self._topk_ids, self._topk_log_probs, self._dec_out_state],
feed_dict=feed)
ids, probs, states = results[0], results[1], results[2]
new_states = [s for s in states]
return ids, probs, new_states
def build_graph(self):
self._add_placeholders()
self._add_seq2seq()
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self._hps.mode == 'train':
self._add_train_op()
self._summaries = tf.merge_all_summaries()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""seq2seq library codes copied from elsewhere for customization."""
import tensorflow as tf
# Adapted to support sampled_softmax loss function, which accepts activations
# instead of logits.
def sequence_loss_by_example(inputs, targets, weights, loss_function,
average_across_timesteps=True, name=None):
"""Sampled softmax loss for a sequence of inputs (per example).
Args:
inputs: List of 2D Tensors of shape [batch_size x hid_dim].
targets: List of 1D batch-sized int32 Tensors of the same length as inputs.
weights: List of 1D batch-sized float Tensors of the same length as inputs.
loss_function: Sampled softmax function (inputs, labels) -> loss
average_across_timesteps: If set, divide the returned cost by the total
label weight.
name: Optional name for this operation, default: 'sequence_loss_by_example'.
Returns:
1D batch-sized float Tensor: The log-perplexity for each sequence.
Raises:
ValueError: If len(inputs) is different from len(targets) or len(weights).
"""
if len(targets) != len(inputs) or len(weights) != len(inputs):
raise ValueError('Lengths of inputs, weights, and targets must be the same '
'%d, %d, %d.' % (len(inputs), len(weights), len(targets)))
with tf.op_scope(inputs + targets + weights, name,
'sequence_loss_by_example'):
log_perp_list = []
for inp, target, weight in zip(inputs, targets, weights):
crossent = loss_function(inp, target)
log_perp_list.append(crossent * weight)
log_perps = tf.add_n(log_perp_list)
if average_across_timesteps:
total_size = tf.add_n(weights)
total_size += 1e-12 # Just to avoid division by 0 for all-0 weights.
log_perps /= total_size
return log_perps
def sampled_sequence_loss(inputs, targets, weights, loss_function,
average_across_timesteps=True,
average_across_batch=True, name=None):
"""Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
Args:
inputs: List of 2D Tensors of shape [batch_size x hid_dim].
targets: List of 1D batch-sized int32 Tensors of the same length as inputs.
weights: List of 1D batch-sized float-Tensors of the same length as inputs.
loss_function: Sampled softmax function (inputs, labels) -> loss
average_across_timesteps: If set, divide the returned cost by the total
label weight.
average_across_batch: If set, divide the returned cost by the batch size.
name: Optional name for this operation, defaults to 'sequence_loss'.
Returns:
A scalar float Tensor: The average log-perplexity per symbol (weighted).
Raises:
ValueError: If len(inputs) is different from len(targets) or len(weights).
"""
with tf.op_scope(inputs + targets + weights, name, 'sampled_sequence_loss'):
cost = tf.reduce_sum(sequence_loss_by_example(
inputs, targets, weights, loss_function,
average_across_timesteps=average_across_timesteps))
if average_across_batch:
batch_size = tf.shape(targets[0])[0]
return cost / tf.cast(batch_size, tf.float32)
else:
return cost
def linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
Args:
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
scope: VariableScope for the created subgraph; defaults to "Linear".
Returns:
A 2D Tensor with shape [batch x output_size] equal to
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
Raises:
ValueError: if one of the arguments has an unspecified or wrong shape.
"""
if args is None or (isinstance(args, (list, tuple)) and not args):
raise ValueError('`args` must be specified')
if not isinstance(args, (list, tuple)):
args = [args]
# Calculate the total size of arguments on dimension 1.
total_arg_size = 0
shapes = [a.get_shape().as_list() for a in args]
for shape in shapes:
if len(shape) != 2:
raise ValueError('Linear is expecting 2D arguments: %s' % str(shapes))
if not shape[1]:
raise ValueError('Linear expects shape[1] of arguments: %s' % str(shapes))
else:
total_arg_size += shape[1]
# Now the computation.
with tf.variable_scope(scope or 'Linear'):
matrix = tf.get_variable('Matrix', [total_arg_size, output_size])
if len(args) == 1:
res = tf.matmul(args[0], matrix)
else:
res = tf.matmul(tf.concat(1, args), matrix)
if not bias:
return res
bias_term = tf.get_variable(
'Bias', [output_size],
initializer=tf.constant_initializer(bias_start))
return res + bias_term