Commit 124a501d authored by Quoc Le

fix the readme

parent 33c4e784
@@ -20,3 +20,22 @@ To propose a model for inclusion please submit a pull request.
- [textsum](textsum) -- sequence-to-sequence with attention model for text summarization.
- [transformer](transformer) -- spatial transformer network, which allows the spatial manipulation of data within the network
- [im2txt](im2txt) -- image-to-text neural network for image captioning.
Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge
Download the data from http://www-nlp.stanford.edu/software/sempre/wikitable/
Change the data_dir flag to the location of the data.
Training:
python neural_programmer.py
The models are written to FLAGS.output_dir.
Testing:
python neural_programmer.py --evaluator_job=True
The models are loaded from FLAGS.output_dir.
The evaluation is done on development data.
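For example, to train and then evaluate with explicit data and output locations (the paths below are just the flag defaults; point them at wherever the WikiTableQuestions data and model directory actually live):
python neural_programmer.py --data_dir=../data/ --output_dir=../model/
python neural_programmer.py --evaluator_job=True --data_dir=../data/ --output_dir=../model/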
Maintained by Arvind Neelakantan (arvind2505)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for constructing vocabulary, converting the examples to integer format and building the required masks for batch computation Author: aneelakantan (Arvind Neelakantan)
"""
import copy
import numbers
import numpy as np
import wiki_data
def return_index(a):
for i in range(len(a)):
if (a[i] == 1.0):
return i
def construct_vocab(data, utility, add_word=False):
ans = []
for example in data:
sent = ""
for word in example.question:
if (not (isinstance(word, numbers.Number))):
sent += word + " "
example.original_nc = copy.deepcopy(example.number_columns)
example.original_wc = copy.deepcopy(example.word_columns)
example.original_nc_names = copy.deepcopy(example.number_column_names)
example.original_wc_names = copy.deepcopy(example.word_column_names)
if (add_word):
continue
number_found = 0
if (not (example.is_bad_example)):
for word in example.question:
if (isinstance(word, numbers.Number)):
number_found += 1
else:
if (not (utility.word_ids.has_key(word))):
utility.words.append(word)
utility.word_count[word] = 1
utility.word_ids[word] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[word]] = word
else:
utility.word_count[word] += 1
for col_name in example.word_column_names:
for word in col_name:
if (isinstance(word, numbers.Number)):
number_found += 1
else:
if (not (utility.word_ids.has_key(word))):
utility.words.append(word)
utility.word_count[word] = 1
utility.word_ids[word] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[word]] = word
else:
utility.word_count[word] += 1
for col_name in example.number_column_names:
for word in col_name:
if (isinstance(word, numbers.Number)):
number_found += 1
else:
if (not (utility.word_ids.has_key(word))):
utility.words.append(word)
utility.word_count[word] = 1
utility.word_ids[word] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[word]] = word
else:
utility.word_count[word] += 1
def word_lookup(word, utility):
if (utility.word_ids.has_key(word)):
return word
else:
return utility.unk_token
def convert_to_int_2d_and_pad(a, utility):
ans = []
#print a
for b in a:
temp = []
if (len(b) > utility.FLAGS.max_entry_length):
b = b[0:utility.FLAGS.max_entry_length]
for remaining in range(len(b), utility.FLAGS.max_entry_length):
b.append(utility.dummy_token)
assert len(b) == utility.FLAGS.max_entry_length
for word in b:
temp.append(utility.word_ids[word_lookup(word, utility)])
ans.append(temp)
#print ans
return ans
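# Illustrative sketch of convert_to_int_2d_and_pad: assuming FLAGS.max_entry_length
# is 2 and a hypothetical vocabulary {"gold": 7, "dummy_token": 0, "UNK": 1}, the
# input [["gold"], ["silver", "medal", "count"]] would become [[7, 0], [1, 1]]:
# short entries are padded with the dummy token, long entries are truncated, and
# out-of-vocabulary words map to the UNK id via word_lookup.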
def convert_to_bool_and_pad(a, utility):
a = a.tolist()
for i in range(len(a)):
for j in range(len(a[i])):
if (a[i][j] < 1):
a[i][j] = False
else:
a[i][j] = True
a[i] = a[i] + [False] * (utility.FLAGS.max_elements - len(a[i]))
return a
seen_tables = {}
def partial_match(question, table, number):
answer = []
match = {}
for i in range(len(table)):
temp = []
for j in range(len(table[i])):
temp.append(0)
answer.append(temp)
for i in range(len(table)):
for j in range(len(table[i])):
for word in question:
if (number):
if (word == table[i][j]):
answer[i][j] = 1.0
match[i] = 1.0
else:
if (word in table[i][j]):
answer[i][j] = 1.0
match[i] = 1.0
return answer, match
def exact_match(question, table, number):
#performs exact match operation
answer = []
match = {}
matched_indices = []
for i in range(len(table)):
temp = []
for j in range(len(table[i])):
temp.append(0)
answer.append(temp)
for i in range(len(table)):
for j in range(len(table[i])):
if (number):
for word in question:
if (word == table[i][j]):
match[i] = 1.0
answer[i][j] = 1.0
else:
table_entry = table[i][j]
for k in range(len(question)):
if (k + len(table_entry) <= len(question)):
if (table_entry == question[k:(k + len(table_entry))]):
#if(len(table_entry) == 1):
#print "match: ", table_entry, question
match[i] = 1.0
answer[i][j] = 1.0
matched_indices.append((k, len(table_entry)))
return answer, match, matched_indices
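# Illustrative example (toy inputs, not from the dataset): with the question
# ["which", "province", "is", "aloi", "in"] and a single word column
# [["aloi"], ["centre"]], i.e. table = [[["aloi"], ["centre"]]],
#   exact_match(question, table, number=False)
# returns answer == [[1.0, 0]], match == {0: 1.0} and matched_indices == [(3, 1)]
# because the full entry ["aloi"] occurs as a contiguous span of the question,
# while partial_match fires whenever any single question word appears inside an
# entry, so it also matches multi-word entries that only share one word.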
def partial_column_match(question, table, number):
answer = []
for i in range(len(table)):
answer.append(0)
for i in range(len(table)):
for word in question:
if (word in table[i]):
answer[i] = 1.0
return answer
def exact_column_match(question, table, number):
#performs exact match on column names
answer = []
matched_indices = []
for i in range(len(table)):
answer.append(0)
for i in range(len(table)):
table_entry = table[i]
for k in range(len(question)):
if (k + len(table_entry) <= len(question)):
if (table_entry == question[k:(k + len(table_entry))]):
answer[i] = 1.0
matched_indices.append((k, len(table_entry)))
return answer, matched_indices
def get_max_entry(a):
e = {}
for w in a:
if (w != "UNK, "):
if (e.has_key(w)):
e[w] += 1
else:
e[w] = 1
if (len(e) > 0):
(key, val) = sorted(e.items(), key=lambda x: -1 * x[1])[0]
if (val > 1):
return key
else:
return -1.0
else:
return -1.0
def list_join(a):
ans = ""
for w in a:
ans += str(w) + ", "
return ans
def group_by_max(table, number):
  #computes the most frequently occurring entry in a column
answer = []
for i in range(len(table)):
temp = []
for j in range(len(table[i])):
temp.append(0)
answer.append(temp)
for i in range(len(table)):
if (number):
curr = table[i]
else:
curr = [list_join(w) for w in table[i]]
max_entry = get_max_entry(curr)
#print i, max_entry
for j in range(len(curr)):
if (max_entry == curr[j]):
answer[i][j] = 1.0
else:
answer[i][j] = 0.0
return answer
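# Illustrative example: for a single number column [[2005, 2007, 2005]],
# group_by_max(table, True) returns [[1.0, 0.0, 1.0]], marking the rows that hold
# the most frequent entry; when no entry repeats, get_max_entry returns -1.0 and
# the whole column mask stays 0.0.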
def pick_one(a):
for i in range(len(a)):
if (1.0 in a[i]):
return True
return False
def check_processed_cols(col, utility):
return True in [
True for y in col
if (y != utility.FLAGS.pad_int and y !=
utility.FLAGS.bad_number_pre_process)
]
def complete_wiki_processing(data, utility, train=True):
#convert to integers and padding
processed_data = []
num_bad_examples = 0
for example in data:
number_found = 0
if (example.is_bad_example):
num_bad_examples += 1
if (not (example.is_bad_example)):
example.string_question = example.question[:]
#entry match
example.processed_number_columns = example.processed_number_columns[:]
example.processed_word_columns = example.processed_word_columns[:]
example.word_exact_match, word_match, matched_indices = exact_match(
example.string_question, example.original_wc, number=False)
example.number_exact_match, number_match, _ = exact_match(
example.string_question, example.original_nc, number=True)
if (not (pick_one(example.word_exact_match)) and not (
pick_one(example.number_exact_match))):
assert len(word_match) == 0
assert len(number_match) == 0
example.word_exact_match, word_match = partial_match(
example.string_question, example.original_wc, number=False)
#group by max
example.word_group_by_max = group_by_max(example.original_wc, False)
example.number_group_by_max = group_by_max(example.original_nc, True)
#column name match
example.word_column_exact_match, wcol_matched_indices = exact_column_match(
example.string_question, example.original_wc_names, number=False)
example.number_column_exact_match, ncol_matched_indices = exact_column_match(
example.string_question, example.original_nc_names, number=False)
if (not (1.0 in example.word_column_exact_match) and not (
1.0 in example.number_column_exact_match)):
example.word_column_exact_match = partial_column_match(
example.string_question, example.original_wc_names, number=False)
example.number_column_exact_match = partial_column_match(
example.string_question, example.original_nc_names, number=False)
if (len(word_match) > 0 or len(number_match) > 0):
example.question.append(utility.entry_match_token)
if (1.0 in example.word_column_exact_match or
1.0 in example.number_column_exact_match):
example.question.append(utility.column_match_token)
example.string_question = example.question[:]
example.number_lookup_matrix = np.transpose(
example.number_lookup_matrix)[:]
example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
example.columns = example.number_columns[:]
example.word_columns = example.word_columns[:]
example.len_total_cols = len(example.word_column_names) + len(
example.number_column_names)
example.column_names = example.number_column_names[:]
example.word_column_names = example.word_column_names[:]
example.string_column_names = example.number_column_names[:]
example.string_word_column_names = example.word_column_names[:]
example.sorted_number_index = []
example.sorted_word_index = []
example.column_mask = []
example.word_column_mask = []
example.processed_column_mask = []
example.processed_word_column_mask = []
example.word_column_entry_mask = []
example.question_attention_mask = []
example.question_number = example.question_number_1 = -1
example.question_attention_mask = []
example.ordinal_question = []
example.ordinal_question_one = []
new_question = []
if (len(example.number_columns) > 0):
example.len_col = len(example.number_columns[0])
else:
example.len_col = len(example.word_columns[0])
for (start, length) in matched_indices:
for j in range(length):
example.question[start + j] = utility.unk_token
#print example.question
for word in example.question:
if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
if (not (isinstance(word, numbers.Number)) and
wiki_data.is_date(word)):
word = word.replace("X", "").replace("-", "")
number_found += 1
if (number_found == 1):
example.question_number = word
if (len(example.ordinal_question) > 0):
example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
else:
example.ordinal_question.append(1.0)
elif (number_found == 2):
example.question_number_1 = word
if (len(example.ordinal_question_one) > 0):
example.ordinal_question_one[len(example.ordinal_question_one) -
1] = 1.0
else:
example.ordinal_question_one.append(1.0)
else:
new_question.append(word)
example.ordinal_question.append(0.0)
example.ordinal_question_one.append(0.0)
example.question = [
utility.word_ids[word_lookup(w, utility)] for w in new_question
]
example.question_attention_mask = [0.0] * len(example.question)
#when the first question number occurs before a word
example.ordinal_question = example.ordinal_question[0:len(
example.question)]
example.ordinal_question_one = example.ordinal_question_one[0:len(
example.question)]
#question-padding
example.question = [utility.word_ids[utility.dummy_token]] * (
utility.FLAGS.question_length - len(example.question)
) + example.question
example.question_attention_mask = [-10000.0] * (
utility.FLAGS.question_length - len(example.question_attention_mask)
) + example.question_attention_mask
example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
len(example.ordinal_question)
) + example.ordinal_question
example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
len(example.ordinal_question_one)
) + example.ordinal_question_one
if (True):
#number columns and related-padding
num_cols = len(example.columns)
start = 0
for column in example.number_columns:
if (check_processed_cols(example.processed_number_columns[start],
utility)):
example.processed_column_mask.append(0.0)
sorted_index = sorted(
range(len(example.processed_number_columns[start])),
key=lambda k: example.processed_number_columns[start][k],
reverse=True)
sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
utility.FLAGS.max_elements - len(sorted_index))
example.sorted_number_index.append(sorted_index)
example.columns[start] = column + [utility.FLAGS.pad_int] * (
utility.FLAGS.max_elements - len(column))
example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
utility.FLAGS.max_elements -
len(example.processed_number_columns[start]))
start += 1
example.column_mask.append(0.0)
for remaining in range(num_cols, utility.FLAGS.max_number_cols):
example.sorted_number_index.append([utility.FLAGS.pad_int] *
(utility.FLAGS.max_elements))
example.columns.append([utility.FLAGS.pad_int] *
(utility.FLAGS.max_elements))
example.processed_number_columns.append([utility.FLAGS.pad_int] *
(utility.FLAGS.max_elements))
example.number_exact_match.append([0.0] *
(utility.FLAGS.max_elements))
example.number_group_by_max.append([0.0] *
(utility.FLAGS.max_elements))
example.column_mask.append(-100000000.0)
example.processed_column_mask.append(-100000000.0)
example.number_column_exact_match.append(0.0)
example.column_names.append([utility.dummy_token])
#word column and related-padding
start = 0
word_num_cols = len(example.word_columns)
for column in example.word_columns:
if (check_processed_cols(example.processed_word_columns[start],
utility)):
example.processed_word_column_mask.append(0.0)
sorted_index = sorted(
range(len(example.processed_word_columns[start])),
key=lambda k: example.processed_word_columns[start][k],
reverse=True)
sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
utility.FLAGS.max_elements - len(sorted_index))
example.sorted_word_index.append(sorted_index)
column = convert_to_int_2d_and_pad(column, utility)
example.word_columns[start] = column + [[
utility.word_ids[utility.dummy_token]
] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
len(column))
example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
utility.FLAGS.max_elements -
len(example.processed_word_columns[start]))
example.word_column_entry_mask.append([0] * len(column) + [
utility.word_ids[utility.dummy_token]
] * (utility.FLAGS.max_elements - len(column)))
start += 1
example.word_column_mask.append(0.0)
for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
example.sorted_word_index.append([utility.FLAGS.pad_int] *
(utility.FLAGS.max_elements))
example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
utility.FLAGS.max_entry_length] *
(utility.FLAGS.max_elements))
example.word_column_entry_mask.append(
[utility.word_ids[utility.dummy_token]] *
(utility.FLAGS.max_elements))
example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
example.processed_word_columns.append([utility.FLAGS.pad_int] *
(utility.FLAGS.max_elements))
example.word_column_mask.append(-100000000.0)
example.processed_word_column_mask.append(-100000000.0)
example.word_column_exact_match.append(0.0)
example.word_column_names.append([utility.dummy_token] *
utility.FLAGS.max_entry_length)
seen_tables[example.table_key] = 1
#convert column and word column names to integers
example.column_ids = convert_to_int_2d_and_pad(example.column_names,
utility)
example.word_column_ids = convert_to_int_2d_and_pad(
example.word_column_names, utility)
for i_em in range(len(example.number_exact_match)):
example.number_exact_match[i_em] = example.number_exact_match[
i_em] + [0.0] * (utility.FLAGS.max_elements -
len(example.number_exact_match[i_em]))
example.number_group_by_max[i_em] = example.number_group_by_max[
i_em] + [0.0] * (utility.FLAGS.max_elements -
len(example.number_group_by_max[i_em]))
for i_em in range(len(example.word_exact_match)):
example.word_exact_match[i_em] = example.word_exact_match[
i_em] + [0.0] * (utility.FLAGS.max_elements -
len(example.word_exact_match[i_em]))
example.word_group_by_max[i_em] = example.word_group_by_max[
i_em] + [0.0] * (utility.FLAGS.max_elements -
len(example.word_group_by_max[i_em]))
example.exact_match = example.number_exact_match + example.word_exact_match
example.group_by_max = example.number_group_by_max + example.word_group_by_max
example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
#answer and related mask, padding
if (example.is_lookup):
example.answer = example.calc_answer
example.number_print_answer = example.number_lookup_matrix.tolist()
example.word_print_answer = example.word_lookup_matrix.tolist()
for i_answer in range(len(example.number_print_answer)):
example.number_print_answer[i_answer] = example.number_print_answer[
i_answer] + [0.0] * (utility.FLAGS.max_elements -
len(example.number_print_answer[i_answer]))
for i_answer in range(len(example.word_print_answer)):
example.word_print_answer[i_answer] = example.word_print_answer[
i_answer] + [0.0] * (utility.FLAGS.max_elements -
len(example.word_print_answer[i_answer]))
example.number_lookup_matrix = convert_to_bool_and_pad(
example.number_lookup_matrix, utility)
example.word_lookup_matrix = convert_to_bool_and_pad(
example.word_lookup_matrix, utility)
for remaining in range(num_cols, utility.FLAGS.max_number_cols):
example.number_lookup_matrix.append([False] *
utility.FLAGS.max_elements)
example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
example.word_lookup_matrix.append([False] *
utility.FLAGS.max_elements)
example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
example.print_answer = example.number_print_answer + example.word_print_answer
else:
example.answer = example.calc_answer
example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
#question_number masks
if (example.question_number == -1):
example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
else:
example.question_number_mask = np.ones([utility.FLAGS.max_elements])
if (example.question_number_1 == -1):
example.question_number_one_mask = -10000.0
else:
example.question_number_one_mask = np.float64(0.0)
if (example.len_col > utility.FLAGS.max_elements):
continue
processed_data.append(example)
return processed_data
def add_special_words(utility):
utility.words.append(utility.entry_match_token)
utility.word_ids[utility.entry_match_token] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[
utility.entry_match_token]] = utility.entry_match_token
utility.entry_match_token_id = utility.word_ids[utility.entry_match_token]
print "entry match token: ", utility.word_ids[
utility.entry_match_token], utility.entry_match_token_id
utility.words.append(utility.column_match_token)
utility.word_ids[utility.column_match_token] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[
utility.column_match_token]] = utility.column_match_token
utility.column_match_token_id = utility.word_ids[utility.column_match_token]
print "entry match token: ", utility.word_ids[
utility.column_match_token], utility.column_match_token_id
utility.words.append(utility.dummy_token)
utility.word_ids[utility.dummy_token] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[
utility.dummy_token]] = utility.dummy_token
utility.dummy_token_id = utility.word_ids[utility.dummy_token]
utility.words.append(utility.unk_token)
utility.word_ids[utility.unk_token] = len(utility.word_ids)
utility.reverse_word_ids[utility.word_ids[
utility.unk_token]] = utility.unk_token
def perform_word_cutoff(utility):
if (utility.FLAGS.word_cutoff > 0):
for word in utility.word_ids.keys():
if (utility.word_count.has_key(word) and utility.word_count[word] <
utility.FLAGS.word_cutoff and word != utility.unk_token and
word != utility.dummy_token and word != utility.entry_match_token and
word != utility.column_match_token):
utility.word_ids.pop(word)
utility.words.remove(word)
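# Note on word_dropout below: FLAGS.word_dropout_prob acts as a keep probability;
# each non-dummy token is replaced by the UNK id with probability
# 1 - word_dropout_prob, and dummy (padding) tokens are always kept.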
def word_dropout(question, utility):
if (utility.FLAGS.word_dropout_prob > 0.0):
new_question = []
for i in range(len(question)):
if (question[i] != utility.dummy_token_id and
utility.random.random() > utility.FLAGS.word_dropout_prob):
new_question.append(utility.word_ids[utility.unk_token])
else:
new_question.append(question[i])
return new_question
else:
return question
def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  #prepare the feed dictionary
feed_dict = {}
feed_examples = []
for j in range(batch_size):
feed_examples.append(data[curr + j])
if (train):
feed_dict[gr.batch_question] = [
word_dropout(feed_examples[j].question, utility)
for j in range(batch_size)
]
else:
feed_dict[gr.batch_question] = [
feed_examples[j].question for j in range(batch_size)
]
feed_dict[gr.batch_question_attention_mask] = [
feed_examples[j].question_attention_mask for j in range(batch_size)
]
feed_dict[
gr.batch_answer] = [feed_examples[j].answer for j in range(batch_size)]
feed_dict[gr.batch_number_column] = [
feed_examples[j].columns for j in range(batch_size)
]
feed_dict[gr.batch_processed_number_column] = [
feed_examples[j].processed_number_columns for j in range(batch_size)
]
feed_dict[gr.batch_processed_sorted_index_number_column] = [
feed_examples[j].sorted_number_index for j in range(batch_size)
]
feed_dict[gr.batch_processed_sorted_index_word_column] = [
feed_examples[j].sorted_word_index for j in range(batch_size)
]
feed_dict[gr.batch_question_number] = np.array(
[feed_examples[j].question_number for j in range(batch_size)]).reshape(
(batch_size, 1))
feed_dict[gr.batch_question_number_one] = np.array(
[feed_examples[j].question_number_1 for j in range(batch_size)]).reshape(
(batch_size, 1))
feed_dict[gr.batch_question_number_mask] = [
feed_examples[j].question_number_mask for j in range(batch_size)
]
feed_dict[gr.batch_question_number_one_mask] = np.array(
[feed_examples[j].question_number_one_mask for j in range(batch_size)
]).reshape((batch_size, 1))
feed_dict[gr.batch_print_answer] = [
feed_examples[j].print_answer for j in range(batch_size)
]
feed_dict[gr.batch_exact_match] = [
feed_examples[j].exact_match for j in range(batch_size)
]
feed_dict[gr.batch_group_by_max] = [
feed_examples[j].group_by_max for j in range(batch_size)
]
feed_dict[gr.batch_column_exact_match] = [
feed_examples[j].exact_column_match for j in range(batch_size)
]
feed_dict[gr.batch_ordinal_question] = [
feed_examples[j].ordinal_question for j in range(batch_size)
]
feed_dict[gr.batch_ordinal_question_one] = [
feed_examples[j].ordinal_question_one for j in range(batch_size)
]
feed_dict[gr.batch_number_column_mask] = [
feed_examples[j].column_mask for j in range(batch_size)
]
feed_dict[gr.batch_number_column_names] = [
feed_examples[j].column_ids for j in range(batch_size)
]
feed_dict[gr.batch_processed_word_column] = [
feed_examples[j].processed_word_columns for j in range(batch_size)
]
feed_dict[gr.batch_word_column_mask] = [
feed_examples[j].word_column_mask for j in range(batch_size)
]
feed_dict[gr.batch_word_column_names] = [
feed_examples[j].word_column_ids for j in range(batch_size)
]
feed_dict[gr.batch_word_column_entry_mask] = [
feed_examples[j].word_column_entry_mask for j in range(batch_size)
]
return feed_dict
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Author: aneelakantan (Arvind Neelakantan)
"""
import numpy as np
import tensorflow as tf
import nn_utils
class Graph():
def __init__(self, utility, batch_size, max_passes, mode="train"):
self.utility = utility
self.data_type = self.utility.tf_data_type[self.utility.FLAGS.data_type]
self.max_elements = self.utility.FLAGS.max_elements
max_elements = self.utility.FLAGS.max_elements
self.num_cols = self.utility.FLAGS.max_number_cols
self.num_word_cols = self.utility.FLAGS.max_word_cols
self.question_length = self.utility.FLAGS.question_length
self.batch_size = batch_size
self.max_passes = max_passes
self.mode = mode
self.embedding_dims = self.utility.FLAGS.embedding_dims
#input question and a mask
self.batch_question = tf.placeholder(tf.int32,
[batch_size, self.question_length])
self.batch_question_attention_mask = tf.placeholder(
self.data_type, [batch_size, self.question_length])
#ground truth scalar answer and lookup answer
self.batch_answer = tf.placeholder(self.data_type, [batch_size])
self.batch_print_answer = tf.placeholder(
self.data_type,
[batch_size, self.num_cols + self.num_word_cols, max_elements])
#number columns and its processed version
self.batch_number_column = tf.placeholder(
self.data_type, [batch_size, self.num_cols, max_elements
]) #columns with numeric entries
self.batch_processed_number_column = tf.placeholder(
self.data_type, [batch_size, self.num_cols, max_elements])
self.batch_processed_sorted_index_number_column = tf.placeholder(
tf.int32, [batch_size, self.num_cols, max_elements])
#word columns and its processed version
self.batch_processed_word_column = tf.placeholder(
self.data_type, [batch_size, self.num_word_cols, max_elements])
self.batch_processed_sorted_index_word_column = tf.placeholder(
tf.int32, [batch_size, self.num_word_cols, max_elements])
self.batch_word_column_entry_mask = tf.placeholder(
tf.int32, [batch_size, self.num_word_cols, max_elements])
#names of word and number columns along with their mask
self.batch_word_column_names = tf.placeholder(
tf.int32,
[batch_size, self.num_word_cols, self.utility.FLAGS.max_entry_length])
self.batch_word_column_mask = tf.placeholder(
self.data_type, [batch_size, self.num_word_cols])
self.batch_number_column_names = tf.placeholder(
tf.int32,
[batch_size, self.num_cols, self.utility.FLAGS.max_entry_length])
self.batch_number_column_mask = tf.placeholder(self.data_type,
[batch_size, self.num_cols])
#exact match and group by max operation
self.batch_exact_match = tf.placeholder(
self.data_type,
[batch_size, self.num_cols + self.num_word_cols, max_elements])
self.batch_column_exact_match = tf.placeholder(
self.data_type, [batch_size, self.num_cols + self.num_word_cols])
self.batch_group_by_max = tf.placeholder(
self.data_type,
[batch_size, self.num_cols + self.num_word_cols, max_elements])
#numbers in the question along with their position. This is used to compute arguments to the comparison operations
self.batch_question_number = tf.placeholder(self.data_type, [batch_size, 1])
self.batch_question_number_one = tf.placeholder(self.data_type,
[batch_size, 1])
self.batch_question_number_mask = tf.placeholder(
self.data_type, [batch_size, max_elements])
self.batch_question_number_one_mask = tf.placeholder(self.data_type,
[batch_size, 1])
self.batch_ordinal_question = tf.placeholder(
self.data_type, [batch_size, self.question_length])
self.batch_ordinal_question_one = tf.placeholder(
self.data_type, [batch_size, self.question_length])
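  # Note: the recurrent dropout mask used in LSTM_question_embedding below is
  # sampled once per batch during training (and rescaled by 1 / rnn_dropout), so
  # the same hidden units are dropped at every timestep of the question LSTM.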
def LSTM_question_embedding(self, sentence, sentence_length):
#LSTM processes the input question
lstm_params = "question_lstm"
hidden_vectors = []
sentence = self.batch_question
question_hidden = tf.zeros(
[self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
question_c_hidden = tf.zeros(
[self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
if (self.utility.FLAGS.rnn_dropout > 0.0):
if (self.mode == "train"):
rnn_dropout_mask = tf.cast(
tf.random_uniform(
tf.shape(question_hidden), minval=0.0, maxval=1.0) <
self.utility.FLAGS.rnn_dropout,
self.data_type) / self.utility.FLAGS.rnn_dropout
else:
rnn_dropout_mask = tf.ones_like(question_hidden)
for question_iterator in range(self.question_length):
curr_word = sentence[:, question_iterator]
question_vector = nn_utils.apply_dropout(
nn_utils.get_embedding(curr_word, self.utility, self.params),
self.utility.FLAGS.dropout, self.mode)
question_hidden, question_c_hidden = nn_utils.LSTMCell(
question_vector, question_hidden, question_c_hidden, lstm_params,
self.params)
if (self.utility.FLAGS.rnn_dropout > 0.0):
question_hidden = question_hidden * rnn_dropout_mask
hidden_vectors.append(tf.expand_dims(question_hidden, 0))
hidden_vectors = tf.concat(0, hidden_vectors)
return question_hidden, hidden_vectors
def history_recurrent_step(self, curr_hprev, hprev):
#A single RNN step for controller or history RNN
return tf.tanh(
tf.matmul(
tf.concat(1, [hprev, curr_hprev]), self.params[
"history_recurrent"])) + self.params["history_recurrent_bias"]
def question_number_softmax(self, hidden_vectors):
    #Attention on question to decide the question number to be passed to comparison ops
def compute_ans(op_embedding, comparison):
op_embedding = tf.expand_dims(op_embedding, 0)
      #dot product of operation embedding with hidden state to the left of the number occurrence
first = tf.transpose(
tf.matmul(op_embedding,
tf.transpose(
tf.reduce_sum(hidden_vectors * tf.tile(
tf.expand_dims(
tf.transpose(self.batch_ordinal_question), 2),
[1, 1, self.utility.FLAGS.embedding_dims]), 0))))
second = self.batch_question_number_one_mask + tf.transpose(
tf.matmul(op_embedding,
tf.transpose(
tf.reduce_sum(hidden_vectors * tf.tile(
tf.expand_dims(
tf.transpose(self.batch_ordinal_question_one), 2
), [1, 1, self.utility.FLAGS.embedding_dims]), 0))))
question_number_softmax = tf.nn.softmax(tf.concat(1, [first, second]))
if (self.mode == "test"):
cond = tf.equal(question_number_softmax,
tf.reshape(
tf.reduce_max(question_number_softmax, 1),
[self.batch_size, 1]))
question_number_softmax = tf.select(
cond,
tf.fill(tf.shape(question_number_softmax), 1.0),
tf.fill(tf.shape(question_number_softmax), 0.0))
question_number_softmax = tf.cast(question_number_softmax,
self.data_type)
ans = tf.reshape(
tf.reduce_sum(question_number_softmax * tf.concat(
1, [self.batch_question_number, self.batch_question_number_one]),
1), [self.batch_size, 1])
return ans
def compute_op_position(op_name):
for i in range(len(self.utility.operations_set)):
if (op_name == self.utility.operations_set[i]):
return i
def compute_question_number(op_name):
op_embedding = tf.nn.embedding_lookup(self.params_unit,
compute_op_position(op_name))
return compute_ans(op_embedding, op_name)
curr_greater_question_number = compute_question_number("greater")
curr_lesser_question_number = compute_question_number("lesser")
curr_geq_question_number = compute_question_number("geq")
curr_leq_question_number = compute_question_number("leq")
return curr_greater_question_number, curr_lesser_question_number, curr_geq_question_number, curr_leq_question_number
def perform_attention(self, context_vector, hidden_vectors, length, mask):
    #Performs attention on hidden_vectors using the context vector
context_vector = tf.tile(
tf.expand_dims(context_vector, 0), [length, 1, 1]) #time * bs * d
attention_softmax = tf.nn.softmax(
tf.transpose(tf.reduce_sum(context_vector * hidden_vectors, 2)) +
mask) #batch_size * time
attention_softmax = tf.tile(
tf.expand_dims(tf.transpose(attention_softmax), 2),
[1, 1, self.embedding_dims])
ans_vector = tf.reduce_sum(attention_softmax * hidden_vectors, 0)
return ans_vector
#computes embeddings for column names using parameters of question module
def get_column_hidden_vectors(self):
#vector representations for the column names
self.column_hidden_vectors = tf.reduce_sum(
nn_utils.get_embedding(self.batch_number_column_names, self.utility,
self.params), 2)
self.word_column_hidden_vectors = tf.reduce_sum(
nn_utils.get_embedding(self.batch_word_column_names, self.utility,
self.params), 2)
def create_summary_embeddings(self):
#embeddings for each text entry in the table using parameters of the question module
self.summary_text_entry_embeddings = tf.reduce_sum(
tf.expand_dims(self.batch_exact_match, 3) * tf.expand_dims(
tf.expand_dims(
tf.expand_dims(
nn_utils.get_embedding(self.utility.entry_match_token_id,
self.utility, self.params), 0), 1),
2), 2)
def compute_column_softmax(self, column_controller_vector, time_step):
#compute softmax over all the columns using column controller vector
column_controller_vector = tf.tile(
tf.expand_dims(column_controller_vector, 1),
[1, self.num_cols + self.num_word_cols, 1]) #max_cols * bs * d
column_controller_vector = nn_utils.apply_dropout(
column_controller_vector, self.utility.FLAGS.dropout, self.mode)
self.full_column_hidden_vectors = tf.concat(
1, [self.column_hidden_vectors, self.word_column_hidden_vectors])
self.full_column_hidden_vectors += self.summary_text_entry_embeddings
self.full_column_hidden_vectors = nn_utils.apply_dropout(
self.full_column_hidden_vectors, self.utility.FLAGS.dropout, self.mode)
column_logits = tf.reduce_sum(
column_controller_vector * self.full_column_hidden_vectors, 2) + (
self.params["word_match_feature_column_name"] *
self.batch_column_exact_match) + self.full_column_mask
column_softmax = tf.nn.softmax(column_logits) #batch_size * max_cols
return column_softmax
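  # Sketch of compute_first_or_last below: the row-select vector is treated as
  # soft probabilities and scanned front-to-back (first=True) or back-to-front,
  # with each row receiving select[i] * (1 - mass already assigned); e.g. a select
  # of [0.3, 0.5, ...] yields first-row probabilities [0.3, 0.5 * (1 - 0.3), ...].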
def compute_first_or_last(self, select, first=True):
    #perform first or last operation on row select with probabilistic row selection
answer = tf.zeros_like(select)
running_sum = tf.zeros([self.batch_size, 1], self.data_type)
for i in range(self.max_elements):
if (first):
current = tf.slice(select, [0, i], [self.batch_size, 1])
else:
current = tf.slice(select, [0, self.max_elements - 1 - i],
[self.batch_size, 1])
curr_prob = current * (1 - running_sum)
curr_prob = curr_prob * tf.cast(curr_prob >= 0.0, self.data_type)
running_sum += curr_prob
temp_ans = []
curr_prob = tf.expand_dims(tf.reshape(curr_prob, [self.batch_size]), 0)
for i_ans in range(self.max_elements):
if (not (first) and i_ans == self.max_elements - 1 - i):
temp_ans.append(curr_prob)
elif (first and i_ans == i):
temp_ans.append(curr_prob)
else:
temp_ans.append(tf.zeros_like(curr_prob))
temp_ans = tf.transpose(tf.concat(0, temp_ans))
answer += temp_ans
return answer
def make_hard_softmax(self, softmax):
#converts soft selection to hard selection. used at test time
cond = tf.equal(
softmax, tf.reshape(tf.reduce_max(softmax, 1), [self.batch_size, 1]))
softmax = tf.select(
cond, tf.fill(tf.shape(softmax), 1.0), tf.fill(tf.shape(softmax), 0.0))
softmax = tf.cast(softmax, self.data_type)
return softmax
def compute_max_or_min(self, select, maxi=True):
    #computes the argmax or argmin of a column with probabilistic row selection
answer = tf.zeros([
self.batch_size, self.num_cols + self.num_word_cols, self.max_elements
], self.data_type)
sum_prob = tf.zeros([self.batch_size, self.num_cols + self.num_word_cols],
self.data_type)
for j in range(self.max_elements):
if (maxi):
curr_pos = j
else:
curr_pos = self.max_elements - 1 - j
select_index = tf.slice(self.full_processed_sorted_index_column,
[0, 0, curr_pos], [self.batch_size, -1, 1])
select_mask = tf.equal(
tf.tile(
tf.expand_dims(
tf.tile(
tf.expand_dims(tf.range(self.max_elements), 0),
[self.batch_size, 1]), 1),
[1, self.num_cols + self.num_word_cols, 1]), select_index)
curr_prob = tf.expand_dims(select, 1) * tf.cast(
select_mask, self.data_type) * self.select_bad_number_mask
curr_prob = curr_prob * tf.expand_dims((1 - sum_prob), 2)
curr_prob = curr_prob * tf.expand_dims(
tf.cast((1 - sum_prob) > 0.0, self.data_type), 2)
answer = tf.select(select_mask, curr_prob, answer)
sum_prob += tf.reduce_sum(curr_prob, 2)
return answer
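  # The operation softmax produced in one_pass is consumed positionally below:
  # index 0 ("count" in utility.operations_set) yields the scalar output, indices
  # 1-13 (the row-selection ops through "reset_select") yield the new row select,
  # and index 14 ("print") yields the lookup answer.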
def perform_operations(self, softmax, full_column_softmax, select,
prev_select_1, curr_pass):
#performs all the 15 operations. computes scalar output, lookup answer and row selector
column_softmax = tf.slice(full_column_softmax, [0, 0],
[self.batch_size, self.num_cols])
word_column_softmax = tf.slice(full_column_softmax, [0, self.num_cols],
[self.batch_size, self.num_word_cols])
init_max = self.compute_max_or_min(select, maxi=True)
init_min = self.compute_max_or_min(select, maxi=False)
#operations that are column independent
count = tf.reshape(tf.reduce_sum(select, 1), [self.batch_size, 1])
select_full_column_softmax = tf.tile(
tf.expand_dims(full_column_softmax, 2),
[1, 1, self.max_elements
]) #BS * (max_cols + max_word_cols) * max_elements
select_word_column_softmax = tf.tile(
tf.expand_dims(word_column_softmax, 2),
[1, 1, self.max_elements]) #BS * max_word_cols * max_elements
select_greater = tf.reduce_sum(
self.init_select_greater * select_full_column_softmax,
1) * self.batch_question_number_mask #BS * max_elements
select_lesser = tf.reduce_sum(
self.init_select_lesser * select_full_column_softmax,
1) * self.batch_question_number_mask #BS * max_elements
select_geq = tf.reduce_sum(
self.init_select_geq * select_full_column_softmax,
1) * self.batch_question_number_mask #BS * max_elements
select_leq = tf.reduce_sum(
self.init_select_leq * select_full_column_softmax,
1) * self.batch_question_number_mask #BS * max_elements
select_max = tf.reduce_sum(init_max * select_full_column_softmax,
1) #BS * max_elements
select_min = tf.reduce_sum(init_min * select_full_column_softmax,
1) #BS * max_elements
select_prev = tf.concat(1, [
tf.slice(select, [0, 1], [self.batch_size, self.max_elements - 1]),
tf.cast(tf.zeros([self.batch_size, 1]), self.data_type)
])
select_next = tf.concat(1, [
tf.cast(tf.zeros([self.batch_size, 1]), self.data_type), tf.slice(
select, [0, 0], [self.batch_size, self.max_elements - 1])
])
select_last_rs = self.compute_first_or_last(select, False)
select_first_rs = self.compute_first_or_last(select, True)
select_word_match = tf.reduce_sum(self.batch_exact_match *
select_full_column_softmax, 1)
select_group_by_max = tf.reduce_sum(self.batch_group_by_max *
select_full_column_softmax, 1)
length_content = 1
length_select = 13
length_print = 1
values = tf.concat(1, [count])
softmax_content = tf.slice(softmax, [0, 0],
[self.batch_size, length_content])
#compute scalar output
output = tf.reduce_sum(tf.mul(softmax_content, values), 1)
#compute lookup answer
softmax_print = tf.slice(softmax, [0, length_content + length_select],
[self.batch_size, length_print])
curr_print = select_full_column_softmax * tf.tile(
tf.expand_dims(select, 1),
[1, self.num_cols + self.num_word_cols, 1
        ]) #BS * max_cols * max_elements (considers only column)
self.batch_lookup_answer = curr_print * tf.tile(
tf.expand_dims(softmax_print, 2),
[1, self.num_cols + self.num_word_cols, self.max_elements
]) #BS * max_cols * max_elements
self.batch_lookup_answer = self.batch_lookup_answer * self.select_full_mask
#compute row select
softmax_select = tf.slice(softmax, [0, length_content],
[self.batch_size, length_select])
select_lists = [
tf.expand_dims(select_prev, 1), tf.expand_dims(select_next, 1),
tf.expand_dims(select_first_rs, 1), tf.expand_dims(select_last_rs, 1),
tf.expand_dims(select_group_by_max, 1),
tf.expand_dims(select_greater, 1), tf.expand_dims(select_lesser, 1),
tf.expand_dims(select_geq, 1), tf.expand_dims(select_leq, 1),
tf.expand_dims(select_max, 1), tf.expand_dims(select_min, 1),
tf.expand_dims(select_word_match, 1),
tf.expand_dims(self.reset_select, 1)
]
select = tf.reduce_sum(
tf.tile(tf.expand_dims(softmax_select, 2), [1, 1, self.max_elements]) *
tf.concat(1, select_lists), 1)
select = select * self.select_whole_mask
return output, select
def one_pass(self, select, question_embedding, hidden_vectors, hprev,
prev_select_1, curr_pass):
#Performs one timestep which involves selecting an operation and a column
attention_vector = self.perform_attention(
hprev, hidden_vectors, self.question_length,
self.batch_question_attention_mask) #batch_size * embedding_dims
controller_vector = tf.nn.relu(
tf.matmul(hprev, self.params["controller_prev"]) + tf.matmul(
tf.concat(1, [question_embedding, attention_vector]), self.params[
"controller"]))
column_controller_vector = tf.nn.relu(
tf.matmul(hprev, self.params["column_controller_prev"]) + tf.matmul(
tf.concat(1, [question_embedding, attention_vector]), self.params[
"column_controller"]))
controller_vector = nn_utils.apply_dropout(
controller_vector, self.utility.FLAGS.dropout, self.mode)
self.operation_logits = tf.matmul(controller_vector,
tf.transpose(self.params_unit))
softmax = tf.nn.softmax(self.operation_logits)
soft_softmax = softmax
#compute column softmax: bs * max_columns
weighted_op_representation = tf.transpose(
tf.matmul(tf.transpose(self.params_unit), tf.transpose(softmax)))
column_controller_vector = tf.nn.relu(
tf.matmul(
tf.concat(1, [
column_controller_vector, weighted_op_representation
]), self.params["break_conditional"]))
full_column_softmax = self.compute_column_softmax(column_controller_vector,
curr_pass)
soft_column_softmax = full_column_softmax
if (self.mode == "test"):
full_column_softmax = self.make_hard_softmax(full_column_softmax)
softmax = self.make_hard_softmax(softmax)
output, select = self.perform_operations(softmax, full_column_softmax,
select, prev_select_1, curr_pass)
return output, select, softmax, soft_softmax, full_column_softmax, soft_column_softmax
def compute_lookup_error(self, val):
#computes lookup error.
cond = tf.equal(self.batch_print_answer, val)
inter = tf.select(
cond, self.init_print_error,
tf.tile(
tf.reshape(tf.constant(1e10, self.data_type), [1, 1, 1]), [
self.batch_size, self.utility.FLAGS.max_word_cols +
self.utility.FLAGS.max_number_cols,
self.utility.FLAGS.max_elements
]))
return tf.reduce_min(tf.reduce_min(inter, 1), 1) * tf.cast(
tf.greater(
tf.reduce_sum(tf.reduce_sum(tf.cast(cond, self.data_type), 1), 1),
0.0), self.data_type)
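  # soft_min below is a smooth approximation of min(x, y):
  #   soft_min(x, y) = max(-(1 / gamma) * log(exp(-gamma * x) + exp(-gamma * y)), 0)
  # with gamma = FLAGS.soft_min_value; it approaches min(x, y) as gamma grows and
  # is clipped at zero.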
def soft_min(self, x, y):
return tf.maximum(-1.0 * (1 / (
self.utility.FLAGS.soft_min_value + 0.0)) * tf.log(
tf.exp(-self.utility.FLAGS.soft_min_value * x) + tf.exp(
-self.utility.FLAGS.soft_min_value * y)), tf.zeros_like(x))
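  # error_computation below picks the supervision signal per example at training
  # time: when the gold answer is both a non-zero scalar and a non-empty lookup
  # answer, the loss is soft_min(math_error, print_error); otherwise only the
  # applicable error is used.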
def error_computation(self):
#computes the error of each example in a batch
math_error = 0.5 * tf.square(tf.sub(self.scalar_output, self.batch_answer))
#scale math error
math_error = math_error / self.rows
math_error = tf.minimum(math_error, self.utility.FLAGS.max_math_error *
tf.ones(tf.shape(math_error), self.data_type))
self.init_print_error = tf.select(
self.batch_gold_select, -1 * tf.log(self.batch_lookup_answer + 1e-300 +
self.invert_select_full_mask), -1 *
tf.log(1 - self.batch_lookup_answer)) * self.select_full_mask
print_error_1 = self.init_print_error * tf.cast(
tf.equal(self.batch_print_answer, 0.0), self.data_type)
print_error = tf.reduce_sum(tf.reduce_sum((print_error_1), 1), 1)
for val in range(1, 58):
print_error += self.compute_lookup_error(val + 0.0)
print_error = print_error * self.utility.FLAGS.print_cost / self.num_entries
if (self.mode == "train"):
error = tf.select(
tf.logical_and(
tf.not_equal(self.batch_answer, 0.0),
tf.not_equal(
tf.reduce_sum(tf.reduce_sum(self.batch_print_answer, 1), 1),
0.0)),
self.soft_min(math_error, print_error),
tf.select(
tf.not_equal(self.batch_answer, 0.0), math_error, print_error))
else:
error = tf.select(
tf.logical_and(
tf.equal(self.scalar_output, 0.0),
tf.equal(
tf.reduce_sum(tf.reduce_sum(self.batch_lookup_answer, 1), 1),
0.0)),
tf.ones_like(math_error),
tf.select(
tf.equal(self.scalar_output, 0.0), print_error, math_error))
return error
def batch_process(self):
#Computes loss and fraction of correct examples in a batch.
self.params_unit = nn_utils.apply_dropout(
self.params["unit"], self.utility.FLAGS.dropout, self.mode)
batch_size = self.batch_size
max_passes = self.max_passes
num_timesteps = 1
max_elements = self.max_elements
select = tf.cast(
tf.fill([self.batch_size, max_elements], 1.0), self.data_type)
hprev = tf.cast(
tf.fill([self.batch_size, self.embedding_dims], 0.0),
self.data_type) #running sum of the hidden states of the model
output = tf.cast(tf.fill([self.batch_size, 1], 0.0),
self.data_type) #output of the model
correct = tf.cast(
tf.fill([1], 0.0), self.data_type
) #to compute accuracy, returns number of correct examples for this batch
total_error = 0.0
prev_select_1 = tf.zeros_like(select)
self.create_summary_embeddings()
self.get_column_hidden_vectors()
#get question embedding
question_embedding, hidden_vectors = self.LSTM_question_embedding(
self.batch_question, self.question_length)
#compute arguments for comparison operation
greater_question_number, lesser_question_number, geq_question_number, leq_question_number = self.question_number_softmax(
hidden_vectors)
self.init_select_greater = tf.cast(
tf.greater(self.full_processed_column,
tf.expand_dims(greater_question_number, 2)), self.
data_type) * self.select_bad_number_mask #bs * max_cols * max_elements
self.init_select_lesser = tf.cast(
tf.less(self.full_processed_column,
tf.expand_dims(lesser_question_number, 2)), self.
data_type) * self.select_bad_number_mask #bs * max_cols * max_elements
self.init_select_geq = tf.cast(
tf.greater_equal(self.full_processed_column,
tf.expand_dims(geq_question_number, 2)), self.
data_type) * self.select_bad_number_mask #bs * max_cols * max_elements
self.init_select_leq = tf.cast(
tf.less_equal(self.full_processed_column,
tf.expand_dims(leq_question_number, 2)), self.
data_type) * self.select_bad_number_mask #bs * max_cols * max_elements
self.init_select_word_match = 0
if (self.utility.FLAGS.rnn_dropout > 0.0):
if (self.mode == "train"):
history_rnn_dropout_mask = tf.cast(
tf.random_uniform(
tf.shape(hprev), minval=0.0, maxval=1.0) <
self.utility.FLAGS.rnn_dropout,
self.data_type) / self.utility.FLAGS.rnn_dropout
else:
history_rnn_dropout_mask = tf.ones_like(hprev)
select = select * self.select_whole_mask
self.batch_log_prob = tf.zeros([self.batch_size], dtype=self.data_type)
#Perform max_passes and at each pass select operation and column
for curr_pass in range(max_passes):
print "step: ", curr_pass
output, select, softmax, soft_softmax, column_softmax, soft_column_softmax = self.one_pass(
select, question_embedding, hidden_vectors, hprev, prev_select_1,
curr_pass)
prev_select_1 = select
#compute input to history RNN
input_op = tf.transpose(
tf.matmul(
tf.transpose(self.params_unit), tf.transpose(
                  soft_softmax))) #weighted average of embedding of operations
input_col = tf.reduce_sum(
tf.expand_dims(soft_column_softmax, 2) *
self.full_column_hidden_vectors, 1)
history_input = tf.concat(1, [input_op, input_col])
history_input = nn_utils.apply_dropout(
history_input, self.utility.FLAGS.dropout, self.mode)
hprev = self.history_recurrent_step(history_input, hprev)
if (self.utility.FLAGS.rnn_dropout > 0.0):
hprev = hprev * history_rnn_dropout_mask
self.scalar_output = output
error = self.error_computation()
cond = tf.less(error, 0.0001, name="cond")
correct_add = tf.select(
cond, tf.fill(tf.shape(cond), 1.0), tf.fill(tf.shape(cond), 0.0))
correct = tf.reduce_sum(correct_add)
error = error / batch_size
total_error = tf.reduce_sum(error)
total_correct = correct / batch_size
return total_error, total_correct
def compute_error(self):
#Sets mask variables and performs batch processing
self.batch_gold_select = self.batch_print_answer > 0.0
self.full_column_mask = tf.concat(
1, [self.batch_number_column_mask, self.batch_word_column_mask])
self.full_processed_column = tf.concat(
1,
[self.batch_processed_number_column, self.batch_processed_word_column])
self.full_processed_sorted_index_column = tf.concat(1, [
self.batch_processed_sorted_index_number_column,
self.batch_processed_sorted_index_word_column
])
self.select_bad_number_mask = tf.cast(
tf.logical_and(
tf.not_equal(self.full_processed_column,
self.utility.FLAGS.pad_int),
tf.not_equal(self.full_processed_column,
self.utility.FLAGS.bad_number_pre_process)),
self.data_type)
self.select_mask = tf.cast(
tf.logical_not(
tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int)),
self.data_type)
self.select_word_mask = tf.cast(
tf.logical_not(
tf.equal(self.batch_word_column_entry_mask,
self.utility.dummy_token_id)), self.data_type)
self.select_full_mask = tf.concat(
1, [self.select_mask, self.select_word_mask])
self.select_whole_mask = tf.maximum(
tf.reshape(
tf.slice(self.select_mask, [0, 0, 0],
[self.batch_size, 1, self.max_elements]),
[self.batch_size, self.max_elements]),
tf.reshape(
tf.slice(self.select_word_mask, [0, 0, 0],
[self.batch_size, 1, self.max_elements]),
[self.batch_size, self.max_elements]))
self.invert_select_full_mask = tf.cast(
tf.concat(1, [
tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int),
tf.equal(self.batch_word_column_entry_mask,
self.utility.dummy_token_id)
]), self.data_type)
self.batch_lookup_answer = tf.zeros(tf.shape(self.batch_gold_select))
self.reset_select = self.select_whole_mask
self.rows = tf.reduce_sum(self.select_whole_mask, 1)
self.num_entries = tf.reshape(
tf.reduce_sum(tf.reduce_sum(self.select_full_mask, 1), 1),
[self.batch_size])
self.final_error, self.final_correct = self.batch_process()
return self.final_error
def create_graph(self, params, global_step):
    #Creates the graph to compute the error, computes gradients and updates parameters
self.params = params
batch_size = self.batch_size
learning_rate = tf.cast(self.utility.FLAGS.learning_rate, self.data_type)
self.total_cost = self.compute_error()
optimize_params = self.params.values()
optimize_names = self.params.keys()
print "optimize params ", optimize_names
if (self.utility.FLAGS.l2_regularizer > 0.0):
reg_cost = 0.0
for ind_param in self.params.keys():
reg_cost += tf.nn.l2_loss(self.params[ind_param])
self.total_cost += self.utility.FLAGS.l2_regularizer * reg_cost
grads = tf.gradients(self.total_cost, optimize_params, name="gradients")
grad_norm = 0.0
for p, name in zip(grads, optimize_names):
print "grads: ", p, name
if isinstance(p, tf.IndexedSlices):
grad_norm += tf.reduce_sum(p.values * p.values)
elif not (p == None):
grad_norm += tf.reduce_sum(p * p)
grad_norm = tf.sqrt(grad_norm)
max_grad_norm = np.float32(self.utility.FLAGS.clip_gradients).astype(
self.utility.np_data_type[self.utility.FLAGS.data_type])
grad_scale = tf.minimum(
tf.cast(1.0, self.data_type), max_grad_norm / grad_norm)
clipped_grads = list()
for p in grads:
if isinstance(p, tf.IndexedSlices):
tmp = p.values * grad_scale
clipped_grads.append(tf.IndexedSlices(tmp, p.indices))
elif not (p == None):
clipped_grads.append(p * grad_scale)
else:
clipped_grads.append(p)
grads = clipped_grads
self.global_step = global_step
params_list = self.params.values()
params_list.append(self.global_step)
adam = tf.train.AdamOptimizer(
learning_rate,
epsilon=tf.cast(self.utility.FLAGS.eps, self.data_type),
use_locking=True)
self.step = adam.apply_gradients(zip(grads, optimize_params),
global_step=self.global_step)
self.init_op = tf.initialize_all_variables()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge
This file calls functions to load and pre-process data, constructs the TF graph,
and performs training or evaluation as specified by the flag evaluator_job.
Author: aneelakantan (Arvind Neelakantan)
"""
import time
from random import Random
import numpy as np
import tensorflow as tf
import model
import wiki_data
import parameters
import data_utils
tf.flags.DEFINE_integer("train_steps", 100001, "Number of steps to train")
tf.flags.DEFINE_integer("eval_cycle", 500,
"Evaluate model at every eval_cycle steps")
tf.flags.DEFINE_integer("max_elements", 100,
"maximum rows that are considered for processing")
tf.flags.DEFINE_integer(
"max_number_cols", 15,
"maximum number columns that are considered for processing")
tf.flags.DEFINE_integer(
"max_word_cols", 25,
"maximum number columns that are considered for processing")
tf.flags.DEFINE_integer("question_length", 62, "maximum question length")
tf.flags.DEFINE_integer("max_entry_length", 1, "")
tf.flags.DEFINE_integer("max_passes", 4, "number of operation passes")
tf.flags.DEFINE_integer("embedding_dims", 256, "")
tf.flags.DEFINE_integer("batch_size", 20, "")
tf.flags.DEFINE_float("clip_gradients", 1.0, "")
tf.flags.DEFINE_float("eps", 1e-6, "")
tf.flags.DEFINE_float("param_init", 0.1, "")
tf.flags.DEFINE_float("learning_rate", 0.001, "")
tf.flags.DEFINE_float("l2_regularizer", 0.0001, "")
tf.flags.DEFINE_float("print_cost", 50.0,
"weighting factor in the objective function")
tf.flags.DEFINE_string("job_id", "temp", """job id""")
tf.flags.DEFINE_string("output_dir", "../model/",
"""output_dir""")
tf.flags.DEFINE_string("data_dir", "../data/",
"""data_dir""")
tf.flags.DEFINE_integer("write_every", 500, "wrtie every N")
tf.flags.DEFINE_integer("param_seed", 150, "")
tf.flags.DEFINE_integer("python_seed", 200, "")
tf.flags.DEFINE_float("dropout", 0.8, "dropout keep probability")
tf.flags.DEFINE_float("rnn_dropout", 0.9,
"dropout keep probability for rnn connections")
tf.flags.DEFINE_float("pad_int", -20000.0,
"number columns are padded with pad_int")
tf.flags.DEFINE_string("data_type", "double", "float or double")
tf.flags.DEFINE_float("word_dropout_prob", 0.9, "word dropout keep prob")
tf.flags.DEFINE_integer("word_cutoff", 10, "")
tf.flags.DEFINE_integer("vocab_size", 10800, "")
tf.flags.DEFINE_boolean("evaluator_job", False,
"wehther to run as trainer/evaluator")
tf.flags.DEFINE_float(
"bad_number_pre_process", -200000.0,
"number that is added to a corrupted table entry in a number column")
tf.flags.DEFINE_float("max_math_error", 3.0,
"max square loss error that is considered")
tf.flags.DEFINE_float("soft_min_value", 5.0, "")
FLAGS = tf.flags.FLAGS
class Utility:
#holds FLAGS and other variables that are used in different files
def __init__(self):
global FLAGS
self.FLAGS = FLAGS
self.unk_token = "UNK"
self.entry_match_token = "entry_match"
self.column_match_token = "column_match"
self.dummy_token = "dummy_token"
self.tf_data_type = {}
self.tf_data_type["double"] = tf.float64
self.tf_data_type["float"] = tf.float32
self.np_data_type = {}
self.np_data_type["double"] = np.float64
self.np_data_type["float"] = np.float32
self.operations_set = ["count"] + [
"prev", "next", "first_rs", "last_rs", "group_by_max", "greater",
"lesser", "geq", "leq", "max", "min", "word-match"
] + ["reset_select"] + ["print"]
self.word_ids = {}
self.reverse_word_ids = {}
self.word_count = {}
self.random = Random(FLAGS.python_seed)
def evaluate(sess, data, batch_size, graph, i):
#computes accuracy
num_examples = 0.0
gc = 0.0
for j in range(0, len(data) - batch_size + 1, batch_size):
[ct] = sess.run([graph.final_correct],
feed_dict=data_utils.generate_feed_dict(data, j, batch_size,
graph))
gc += ct * batch_size
num_examples += batch_size
print "dev set accuracy after ", i, " : ", gc / num_examples
print num_examples, len(data)
print "--------"
def Train(graph, utility, batch_size, train_data, sess, model_dir,
saver):
#performs training
curr = 0
train_set_loss = 0.0
utility.random.shuffle(train_data)
start = time.time()
for i in range(utility.FLAGS.train_steps):
curr_step = i
if (i > 0 and i % FLAGS.write_every == 0):
model_file = model_dir + "/model_" + str(i)
saver.save(sess, model_file)
if curr + batch_size >= len(train_data):
curr = 0
utility.random.shuffle(train_data)
step, cost_value = sess.run(
[graph.step, graph.total_cost],
feed_dict=data_utils.generate_feed_dict(
train_data, curr, batch_size, graph, train=True, utility=utility))
curr = curr + batch_size
train_set_loss += cost_value
if (i > 0 and i % FLAGS.eval_cycle == 0):
end = time.time()
time_taken = end - start
print "step ", i, " ", time_taken, " seconds "
start = end
print " printing train set loss: ", train_set_loss / utility.FLAGS.eval_cycle
train_set_loss = 0.0
def master(train_data, dev_data, utility):
#creates TF graph and calls trainer or evaluator
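  # In evaluator mode this loops forever, restoring every saved checkpoint in
  # model_dir except the newest one (presumably because it may still be being
  # written) and reporting dev-set accuracy for each; in trainer mode it
  # creates model_dir if needed and calls Train.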
batch_size = utility.FLAGS.batch_size
model_dir = utility.FLAGS.output_dir + "/model" + utility.FLAGS.job_id + "/"
  #create all parameters of the model
param_class = parameters.Parameters(utility)
params, global_step, init = param_class.parameters(utility)
key = "test" if (FLAGS.evaluator_job) else "train"
graph = model.Graph(utility, batch_size, utility.FLAGS.max_passes, mode=key)
graph.create_graph(params, global_step)
prev_dev_error = 0.0
final_loss = 0.0
final_accuracy = 0.0
#start session
with tf.Session() as sess:
sess.run(init.name)
sess.run(graph.init_op.name)
to_save = params.copy()
saver = tf.train.Saver(to_save, max_to_keep=500)
if (FLAGS.evaluator_job):
while True:
selected_models = {}
file_list = tf.gfile.ListDirectory(model_dir)
for model_file in file_list:
if ("checkpoint" in model_file or "index" in model_file or
"meta" in model_file):
continue
if ("data" in model_file):
model_file = model_file.split(".")[0]
model_step = int(
model_file.split("_")[len(model_file.split("_")) - 1])
selected_models[model_step] = model_file
file_list = sorted(selected_models.items(), key=lambda x: x[0])
if (len(file_list) > 0):
file_list = file_list[0:len(file_list) - 1]
print "list of models: ", file_list
for model_file in file_list:
model_file = model_file[1]
print "restoring: ", model_file
saver.restore(sess, model_dir + "/" + model_file)
model_step = int(
model_file.split("_")[len(model_file.split("_")) - 1])
print "evaluating on dev ", model_file, model_step
evaluate(sess, dev_data, batch_size, graph, model_step)
else:
ckpt = tf.train.get_checkpoint_state(model_dir)
print "model dir: ", model_dir
if (not (tf.gfile.IsDirectory(model_dir))):
print "create dir: ", model_dir
tf.gfile.MkDir(model_dir)
Train(graph, utility, batch_size, train_data, sess, model_dir,
saver)
def main(args):
utility = Utility()
train_name = "random-split-1-train.examples"
dev_name = "random-split-1-dev.examples"
test_name = "pristine-unseen-tables.examples"
#load data
dat = wiki_data.WikiQuestionGenerator(train_name, dev_name, test_name, FLAGS.data_dir)
train_data, dev_data, test_data = dat.load()
utility.words = []
utility.word_ids = {}
utility.reverse_word_ids = {}
#construct vocabulary
data_utils.construct_vocab(train_data, utility)
data_utils.construct_vocab(dev_data, utility, True)
data_utils.construct_vocab(test_data, utility, True)
data_utils.add_special_words(utility)
data_utils.perform_word_cutoff(utility)
#convert data to int format and pad the inputs
train_data = data_utils.complete_wiki_processing(train_data, utility, True)
dev_data = data_utils.complete_wiki_processing(dev_data, utility, False)
test_data = data_utils.complete_wiki_processing(test_data, utility, False)
print "# train examples ", len(train_data)
print "# dev examples ", len(dev_data)
print "# test examples ", len(test_data)
print "running open source"
#construct TF graph and train or evaluate
master(train_data, dev_data, utility)
if __name__ == "__main__":
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Author: aneelakantan (Arvind Neelakantan)
"""
import tensorflow as tf
def get_embedding(word, utility, params):
return tf.nn.embedding_lookup(params["word"], word)
def apply_dropout(x, dropout_rate, mode):
  # dropout_rate is a keep probability; dropout is applied only in train mode.
  if (dropout_rate > 0.0 and mode == "train"):
    x = tf.nn.dropout(x, dropout_rate)
  return x
def LSTMCell(x, mprev, cprev, key, params):
"""Create an LSTM cell.
Implements the equations in pg.2 from
"Long Short-Term Memory Based Recurrent Neural Network Architectures
For Large Vocabulary Speech Recognition",
Hasim Sak, Andrew Senior, Francoise Beaufays.
  Args:
    x: Inputs to this cell.
    mprev: m_{t-1}, the recurrent activations (same as the output)
      from the previous cell.
    cprev: c_{t-1}, the cell activations from the previous cell.
    key: Prefix used to look up this cell's weight and bias entries in params.
    params: Dictionary of model parameters.
  Returns:
    m: Outputs of this cell.
    c: Cell Activations.
  """
i = tf.matmul(x, params[key + "_ix"]) + tf.matmul(mprev, params[key + "_im"])
i = tf.nn.bias_add(i, params[key + "_i"])
f = tf.matmul(x, params[key + "_fx"]) + tf.matmul(mprev, params[key + "_fm"])
f = tf.nn.bias_add(f, params[key + "_f"])
c = tf.matmul(x, params[key + "_cx"]) + tf.matmul(mprev, params[key + "_cm"])
c = tf.nn.bias_add(c, params[key + "_c"])
o = tf.matmul(x, params[key + "_ox"]) + tf.matmul(mprev, params[key + "_om"])
o = tf.nn.bias_add(o, params[key + "_o"])
i = tf.sigmoid(i, name="i_gate")
f = tf.sigmoid(f, name="f_gate")
o = tf.sigmoid(o, name="o_gate")
c = f * cprev + i * tf.tanh(c)
m = o * c
return m, c
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Author: aneelakantan (Arvind Neelakantan)
"""
import numpy as np
import tensorflow as tf
class Parameters:
def __init__(self, u):
self.utility = u
self.init_seed_counter = 0
self.word_init = {}
def parameters(self, utility):
params = {}
inits = []
embedding_dims = self.utility.FLAGS.embedding_dims
params["unit"] = tf.Variable(
self.RandomUniformInit([len(utility.operations_set), embedding_dims]))
params["word"] = tf.Variable(
self.RandomUniformInit([utility.FLAGS.vocab_size, embedding_dims]))
params["word_match_feature_column_name"] = tf.Variable(
self.RandomUniformInit([1]))
params["controller"] = tf.Variable(
self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
params["column_controller"] = tf.Variable(
self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
params["column_controller_prev"] = tf.Variable(
self.RandomUniformInit([embedding_dims, embedding_dims]))
params["controller_prev"] = tf.Variable(
self.RandomUniformInit([embedding_dims, embedding_dims]))
global_step = tf.Variable(1, name="global_step")
    #weights of the question and history RNN (or LSTM)
key_list = ["question_lstm"]
for key in key_list:
# Weights going from inputs to nodes.
for wgts in ["ix", "fx", "cx", "ox"]:
params[key + "_" + wgts] = tf.Variable(
self.RandomUniformInit([embedding_dims, embedding_dims]))
# Weights going from nodes to nodes.
for wgts in ["im", "fm", "cm", "om"]:
params[key + "_" + wgts] = tf.Variable(
self.RandomUniformInit([embedding_dims, embedding_dims]))
#Biases for the gates and cell
for bias in ["i", "f", "c", "o"]:
if (bias == "f"):
print "forget gate bias"
params[key + "_" + bias] = tf.Variable(
tf.random_uniform([embedding_dims], 1.0, 1.1, self.utility.
tf_data_type[self.utility.FLAGS.data_type]))
else:
params[key + "_" + bias] = tf.Variable(
self.RandomUniformInit([embedding_dims]))
params["history_recurrent"] = tf.Variable(
self.RandomUniformInit([3 * embedding_dims, embedding_dims]))
params["history_recurrent_bias"] = tf.Variable(
self.RandomUniformInit([1, embedding_dims]))
params["break_conditional"] = tf.Variable(
self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
init = tf.initialize_all_variables()
return params, global_step, init
def RandomUniformInit(self, shape):
"""Returns a RandomUniform Tensor between -param_init and param_init."""
param_seed = self.utility.FLAGS.param_seed
self.init_seed_counter += 1
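    # Each variable gets a distinct, deterministic seed (param_seed plus a
    # running counter), so parameter initialization is reproducible across runs.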
return tf.random_uniform(
shape, -1.0 *
(np.float32(self.utility.FLAGS.param_init)
).astype(self.utility.np_data_type[self.utility.FLAGS.data_type]),
(np.float32(self.utility.FLAGS.param_init)
).astype(self.utility.np_data_type[self.utility.FLAGS.data_type]),
self.utility.tf_data_type[self.utility.FLAGS.data_type],
param_seed + self.init_seed_counter)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads the WikiQuestions dataset.
An example consists of question, table. Additionally, we store the processed
columns which store the entries after performing number, date and other
preprocessing as done in the baseline.
columns, column names and processed columns are split into word and number
columns.
lookup answer (or matrix) is also split into number and word lookup matrix
Author: aneelakantan (Arvind Neelakantan)
"""
import math
import os
import re
import numpy as np
import unicodedata as ud
import tensorflow as tf
bad_number = -200000.0 #number that is added to a corrupted table entry in a number column
def is_nan_or_inf(number):
return math.isnan(number) or math.isinf(number)
def strip_accents(s):
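  # e.g. strip_accents("café") -> "cafe": NFKD-decompose, then drop the
  # combining-mark characters.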
u = unicode(s, "utf-8")
u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
return u_new.encode("utf-8")
def correct_unicode(string):
string = strip_accents(string)
string = re.sub("\xc2\xa0", " ", string).strip()
string = re.sub("\xe2\x80\x93", "-", string).strip()
#string = re.sub(ur'[\u0300-\u036F]', "", string)
string = re.sub("‚", ",", string)
string = re.sub("…", "...", string)
#string = re.sub("[·・]", ".", string)
string = re.sub("ˆ", "^", string)
string = re.sub("˜", "~", string)
string = re.sub("‹", "<", string)
string = re.sub("›", ">", string)
#string = re.sub("[‘’´`]", "'", string)
#string = re.sub("[“”«»]", "\"", string)
#string = re.sub("[•†‡]", "", string)
#string = re.sub("[‐‑–—]", "-", string)
string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
string = re.sub("\\s+", " ", string).strip()
return string
def simple_normalize(string):
string = correct_unicode(string)
# Citations
string = re.sub("\[(nb ?)?\d+\]", "", string)
string = re.sub("\*+$", "", string)
# Year in parenthesis
string = re.sub("\(\d* ?-? ?\d*\)", "", string)
string = re.sub("^\"(.*)\"$", "", string)
return string
def full_normalize(string):
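  # Illustrative: full_normalize("Real Madrid (Spain)") would yield
  # "real madrid" (trailing bracketed/parenthesized info dropped, lowercased,
  # punctuation stripped).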
#print "an: ", string
string = simple_normalize(string)
# Remove trailing info in brackets
string = re.sub("\[[^\]]*\]", "", string)
# Remove most unicode characters in other languages
string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
# Remove trailing info in parenthesis
string = re.sub("\([^)]*\)$", "", string.strip())
string = final_normalize(string)
# Get rid of question marks
string = re.sub("\?", "", string).strip()
# Get rid of trailing colons (usually occur in column titles)
string = re.sub("\:$", " ", string).strip()
# Get rid of slashes
string = re.sub(r"/", " ", string).strip()
string = re.sub(r"\\", " ", string).strip()
# Replace colon, slash, and dash with space
# Note: need better replacement for this when parsing time
string = re.sub(r"\:", " ", string).strip()
string = re.sub("/", " ", string).strip()
string = re.sub("-", " ", string).strip()
# Convert empty strings to UNK
# Important to do this last or near last
if not string:
string = "UNK"
return string
def final_normalize(string):
# Remove leading and trailing whitespace
string = re.sub("\\s+", " ", string).strip()
# Convert entirely to lowercase
string = string.lower()
# Get rid of strangely escaped newline characters
string = re.sub("\\\\n", " ", string).strip()
# Get rid of quotation marks
string = re.sub(r"\"", "", string).strip()
string = re.sub(r"\'", "", string).strip()
string = re.sub(r"`", "", string).strip()
# Get rid of *
string = re.sub("\*", "", string).strip()
return string
def is_number(x):
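  # e.g. is_number("3.5") -> True, is_number("abc") -> False, and
  # is_number("nan") -> False since NaN/inf are rejected.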
try:
f = float(x)
return not is_nan_or_inf(f)
except ValueError:
return False
except TypeError:
return False
class WikiExample(object):
def __init__(self, id, question, answer, table_key):
self.question_id = id
self.question = question
self.answer = answer
self.table_key = table_key
self.lookup_matrix = []
self.is_bad_example = False
self.is_word_lookup = False
self.is_ambiguous_word_lookup = False
self.is_number_lookup = False
self.is_number_calc = False
self.is_unknown_answer = False
class TableInfo(object):
def __init__(self, word_columns, word_column_names, word_column_indices,
number_columns, number_column_names, number_column_indices,
processed_word_columns, processed_number_columns, orig_columns):
self.word_columns = word_columns
self.word_column_names = word_column_names
self.word_column_indices = word_column_indices
self.number_columns = number_columns
self.number_column_names = number_column_names
self.number_column_indices = number_column_indices
self.processed_word_columns = processed_word_columns
self.processed_number_columns = processed_number_columns
self.orig_columns = orig_columns
class WikiQuestionLoader(object):
def __init__(self, data_name, root_folder):
self.root_folder = root_folder
self.data_folder = os.path.join(self.root_folder, "data")
self.examples = []
self.data_name = data_name
def num_questions(self):
return len(self.examples)
def load_qa(self):
data_source = os.path.join(self.data_folder, self.data_name)
f = tf.gfile.GFile(data_source, "r")
id_regex = re.compile("\(id ([^\)]*)\)")
for line in f:
id_match = id_regex.search(line)
id = id_match.group(1)
self.examples.append(id)
def load(self):
self.load_qa()
def is_date(word):
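  # Accepts 10-character YYYY-MM-DD strings where digits may be replaced by
  # X/x wildcards, e.g. "1995-07-14" or "1995-07-xx"; everything else is
  # rejected.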
if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
return False
if (len(word) != 10):
return False
if (word[4] != "-"):
return False
if (word[7] != "-"):
return False
for i in range(len(word)):
if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or re.search(
"[0-9]", word[i]))):
return False
return True
class WikiQuestionGenerator(object):
def __init__(self, train_name, dev_name, test_name, root_folder):
self.train_name = train_name
self.dev_name = dev_name
self.test_name = test_name
self.train_loader = WikiQuestionLoader(train_name, root_folder)
self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
self.test_loader = WikiQuestionLoader(test_name, root_folder)
self.bad_examples = 0
self.root_folder = root_folder
self.data_folder = os.path.join(self.root_folder, "annotated/data")
self.annotated_examples = {}
self.annotated_tables = {}
self.annotated_word_reject = {}
self.annotated_word_reject["-lrb-"] = 1
self.annotated_word_reject["-rrb-"] = 1
self.annotated_word_reject["UNK"] = 1
def is_money(self, word):
if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
return False
for i in range(len(word)):
if (not (word[i] == "E" or word[i] == "." or re.search("[0-9]",
word[i]))):
return False
return True
def remove_consecutive(self, ner_tags, ner_values):
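    # Collapses runs of adjacent tokens that share the same NER tag and value
    # so a multi-token number/date entity contributes a single value: every
    # token in the run except the last has its ner_value overwritten with a
    # placeholder ("A" makes pre_process_sentence keep the original token,
    # "," makes it drop the token).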
for i in range(len(ner_tags)):
if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
word = ner_values[i]
word = word.replace(">", "").replace("<", "").replace("=", "").replace(
"%", "").replace("~", "").replace("$", "").replace("£", "").replace(
"€", "")
if (re.search("[A-Z]", word) and not (is_date(word)) and not (
self.is_money(word))):
ner_values[i] = "A"
else:
ner_values[i] = ","
return ner_tags, ner_values
def pre_process_sentence(self, tokens, ner_tags, ner_values):
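    # Converts a "|"-separated token string into a list of words and floats:
    # NUMBER/MONEY/PERCENT/DATE tokens are replaced by their NER values
    # (converted to float when possible), other tokens are cleaned with
    # full_normalize, rejected or empty tokens are dropped, and an empty
    # result falls back to ["UNK"].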
sentence = []
tokens = tokens.split("|")
ner_tags = ner_tags.split("|")
ner_values = ner_values.split("|")
ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
#print "old: ", tokens
for i in range(len(tokens)):
word = tokens[i]
if (ner_values[i] != "" and
(ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
word = ner_values[i]
word = word.replace(">", "").replace("<", "").replace("=", "").replace(
"%", "").replace("~", "").replace("$", "").replace("£", "").replace(
"€", "")
if (re.search("[A-Z]", word) and not (is_date(word)) and not (
self.is_money(word))):
word = tokens[i]
if (is_number(ner_values[i])):
word = float(ner_values[i])
elif (is_number(word)):
word = float(word)
if (tokens[i] == "score"):
word = "score"
if (is_number(word)):
word = float(word)
if (not (self.annotated_word_reject.has_key(word))):
if (is_number(word) or is_date(word) or self.is_money(word)):
sentence.append(word)
else:
word = full_normalize(word)
if (not (self.annotated_word_reject.has_key(word)) and
bool(re.search("[a-z0-9]", word, re.IGNORECASE))):
m = re.search(",", word)
sentence.append(word.replace(",", ""))
if (len(sentence) == 0):
sentence.append("UNK")
return sentence
def load_annotated_data(self, in_file):
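    # Reads the tab-separated annotated examples file (the first line is a
    # header), builds a WikiExample per question id and records the table
    # each question refers to.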
self.annotated_examples = {}
self.annotated_tables = {}
f = tf.gfile.GFile(in_file, "r")
counter = 0
for line in f:
if (counter > 0):
line = line.strip()
(question_id, utterance, context, target_value, tokens, lemma_tokens,
pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
question = self.pre_process_sentence(tokens, ner_tags, ner_values)
target_canon = target_canon.split("|")
self.annotated_examples[question_id] = WikiExample(
question_id, question, target_canon, context)
self.annotated_tables[context] = []
counter += 1
print "Annotated examples loaded ", len(self.annotated_examples)
f.close()
def is_number_column(self, a):
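    # A column counts as numeric only if every cell is a single token that
    # parses as a number.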
for w in a:
if (len(w) != 1):
return False
if (not (is_number(w[0]))):
return False
return True
def convert_table(self, table):
answer = []
for i in range(len(table)):
temp = []
for j in range(len(table[i])):
temp.append(" ".join([str(w) for w in table[i][j]]))
answer.append(temp)
return answer
def load_annotated_tables(self):
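    # Each annotated table file is read twice: a first pass to find its
    # dimensions, a second to fill the per-column entries (row -1 holds the
    # column names); columns are then partitioned into number and word
    # columns.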
for table in self.annotated_tables.keys():
annotated_table = table.replace("csv", "annotated")
orig_columns = []
processed_columns = []
f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
counter = 0
for line in f:
if (counter > 0):
line = line.strip()
line = line + "\t" * (13 - len(line.split("\t")))
(row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
ner_values, number, date, num2, read_list) = line.split("\t")
counter += 1
f.close()
max_row = int(row)
max_col = int(col)
for i in range(max_col + 1):
orig_columns.append([])
processed_columns.append([])
for j in range(max_row + 1):
orig_columns[i].append(bad_number)
processed_columns[i].append(bad_number)
#print orig_columns
f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
counter = 0
column_names = []
for line in f:
if (counter > 0):
line = line.strip()
line = line + "\t" * (13 - len(line.split("\t")))
(row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
ner_values, number, date, num2, read_list) = line.split("\t")
entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
if (row == "-1"):
column_names.append(entry)
else:
orig_columns[int(col)][int(row)] = entry
if (len(entry) == 1 and is_number(entry[0])):
processed_columns[int(col)][int(row)] = float(entry[0])
else:
for single_entry in entry:
if (is_number(single_entry)):
processed_columns[int(col)][int(row)] = float(single_entry)
break
nt = ner_tags.split("|")
nv = ner_values.split("|")
for i_entry in range(len(tokens.split("|"))):
if (nt[i_entry] == "DATE" and
is_number(nv[i_entry].replace("-", "").replace("X", ""))):
processed_columns[int(col)][int(row)] = float(nv[
i_entry].replace("-", "").replace("X", ""))
#processed_columns[int(col)][int(row)] = float(nv[i_entry])
if (len(entry) == 1 and (is_number(entry[0]) or is_date(entry[0]) or
self.is_money(entry[0]))):
if (len(entry) == 1 and not (is_number(entry[0])) and
is_date(entry[0])):
entry[0] = entry[0].replace("X", "x")
counter += 1
word_columns = []
processed_word_columns = []
word_column_names = []
word_column_indices = []
number_columns = []
processed_number_columns = []
number_column_names = []
number_column_indices = []
for i in range(max_col + 1):
if (self.is_number_column(orig_columns[i])):
number_column_indices.append(i)
number_column_names.append(column_names[i])
temp = []
for w in orig_columns[i]:
if (is_number(w[0])):
temp.append(w[0])
number_columns.append(temp)
processed_number_columns.append(processed_columns[i])
else:
word_column_indices.append(i)
word_column_names.append(column_names[i])
word_columns.append(orig_columns[i])
processed_word_columns.append(processed_columns[i])
table_info = TableInfo(
word_columns, word_column_names, word_column_indices, number_columns,
number_column_names, number_column_indices, processed_word_columns,
processed_number_columns, orig_columns)
self.annotated_tables[table] = table_info
f.close()
def answer_classification(self):
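    # Labels every example as a word lookup, number lookup or number
    # calculation question (or marks it as a bad example) using the
    # exact-match cell annotations in arvind-with-norms-2.tsv, and builds the
    # per-example lookup matrices split into number and word columns.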
lookup_questions = 0
number_lookup_questions = 0
word_lookup_questions = 0
ambiguous_lookup_questions = 0
number_questions = 0
bad_questions = 0
ice_bad_questions = 0
tot = 0
got = 0
ice = {}
with tf.gfile.GFile(
self.root_folder + "/arvind-with-norms-2.tsv", mode="r") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
if (not (self.annotated_examples.has_key(line.split("\t")[0]))):
continue
if (len(line.split("\t")) == 4):
line = line + "\t" * (5 - len(line.split("\t")))
if (not (is_number(line.split("\t")[2]))):
ice_bad_questions += 1
(example_id, ans_index, ans_raw, process_answer,
matched_cells) = line.split("\t")
if (ice.has_key(example_id)):
ice[example_id].append(line.split("\t"))
else:
ice[example_id] = [line.split("\t")]
for q_id in self.annotated_examples.keys():
tot += 1
example = self.annotated_examples[q_id]
table_info = self.annotated_tables[example.table_key]
# Figure out if the answer is numerical or lookup
n_cols = len(table_info.orig_columns)
n_rows = len(table_info.orig_columns[0])
example.lookup_matrix = np.zeros((n_rows, n_cols))
exact_matches = {}
for (example_id, ans_index, ans_raw, process_answer,
matched_cells) in ice[q_id]:
for match_cell in matched_cells.split("|"):
if (len(match_cell.split(",")) == 2):
(row, col) = match_cell.split(",")
row = int(row)
col = int(col)
if (row >= 0):
exact_matches[ans_index] = 1
answer_is_in_table = len(exact_matches) == len(example.answer)
if (answer_is_in_table):
for (example_id, ans_index, ans_raw, process_answer,
matched_cells) in ice[q_id]:
for match_cell in matched_cells.split("|"):
if (len(match_cell.split(",")) == 2):
(row, col) = match_cell.split(",")
row = int(row)
col = int(col)
example.lookup_matrix[row, col] = float(ans_index) + 1.0
example.lookup_number_answer = 0.0
if (answer_is_in_table):
lookup_questions += 1
if len(example.answer) == 1 and is_number(example.answer[0]):
example.number_answer = float(example.answer[0])
number_lookup_questions += 1
example.is_number_lookup = True
else:
#print "word lookup"
example.calc_answer = example.number_answer = 0.0
word_lookup_questions += 1
example.is_word_lookup = True
else:
if (len(example.answer) == 1 and is_number(example.answer[0])):
example.number_answer = example.answer[0]
example.is_number_calc = True
else:
bad_questions += 1
example.is_bad_example = True
example.is_unknown_answer = True
example.is_lookup = example.is_word_lookup or example.is_number_lookup
if not example.is_word_lookup and not example.is_bad_example:
number_questions += 1
example.calc_answer = example.answer[0]
example.lookup_number_answer = example.calc_answer
# Split up the lookup matrix into word part and number part
number_column_indices = table_info.number_column_indices
word_column_indices = table_info.word_column_indices
example.word_columns = table_info.word_columns
example.number_columns = table_info.number_columns
example.word_column_names = table_info.word_column_names
example.processed_number_columns = table_info.processed_number_columns
example.processed_word_columns = table_info.processed_word_columns
example.number_column_names = table_info.number_column_names
example.number_lookup_matrix = example.lookup_matrix[:,
number_column_indices]
example.word_lookup_matrix = example.lookup_matrix[:, word_column_indices]
def load(self):
train_data = []
dev_data = []
test_data = []
self.load_annotated_data(
os.path.join(self.data_folder, "training.annotated"))
self.load_annotated_tables()
self.answer_classification()
self.train_loader.load()
self.dev_loader.load()
for i in range(self.train_loader.num_questions()):
example = self.train_loader.examples[i]
example = self.annotated_examples[example]
train_data.append(example)
for i in range(self.dev_loader.num_questions()):
example = self.dev_loader.examples[i]
dev_data.append(self.annotated_examples[example])
self.load_annotated_data(
os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
self.load_annotated_tables()
self.answer_classification()
self.test_loader.load()
for i in range(self.test_loader.num_questions()):
example = self.test_loader.examples[i]
test_data.append(self.annotated_examples[example])
return train_data, dev_data, test_data