Commit 480f0ebf authored by Yuyu Zhang

commit code

parent 38f008a9
......@@ -33,9 +33,11 @@ The code in this repository is based on the original
## Data
1. Download the [MetaQA dataset](https://goo.gl/f3AmcY). Read the documents
there for dataset details.
2. Put the MetaQA folder in the root directory of this repository.
1. Download the [MetaQA dataset](https://goo.gl/f3AmcY). Click the button
`MetaQA` and then click `Download` in the drop-down list. Extract the zip
file after the download completes. Read the documents there for dataset
details.
2. Move the `MetaQA` folder to the root directory of this repository.
## How to use this code
......
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
def str2bool(v):
return v.lower() in ('true', '1')
def add_argument_group(name):
arg = parser.add_argument_group(name)
arg_lists.append(arg)
return arg
def get_config():
config, unparsed = parser.parse_known_args()
return config, unparsed
arg_lists = []
parser = argparse.ArgumentParser()
work_dir = os.path.abspath(os.path.join(__file__, '../../'))
net_arg = add_argument_group('Network')
net_arg.add_argument('--lstm_dim', type=int, default=128)
net_arg.add_argument('--num_layers', type=int, default=1)
net_arg.add_argument('--embed_dim_txt', type=int, default=128)
net_arg.add_argument('--embed_dim_nmn', type=int, default=128)
net_arg.add_argument(
'--T_encoder', type=int, default=0) # will be updated when reading data
net_arg.add_argument('--T_decoder', type=int, default=5)
train_arg = add_argument_group('Training')
train_arg.add_argument('--train_tag', type=str, default='n2nmn')
train_arg.add_argument('--batch_size', type=int, default=128)
train_arg.add_argument('--max_iter', type=int, default=1000000)
train_arg.add_argument('--weight_decay', type=float, default=1e-5)
train_arg.add_argument('--baseline_decay', type=float, default=0.99)
train_arg.add_argument('--max_grad_norm', type=float, default=10)
train_arg.add_argument('--random_seed', type=int, default=123)
data_arg = add_argument_group('Data')
data_path = work_dir + '/MetaQA/'
data_arg.add_argument('--KB_file', type=str, default=data_path + 'kb.txt')
data_arg.add_argument(
'--data_dir', type=str, default=data_path + '1-hop/vanilla/')
data_arg.add_argument('--train_data_file', type=str, default='qa_train.txt')
data_arg.add_argument('--dev_data_file', type=str, default='qa_dev.txt')
data_arg.add_argument('--test_data_file', type=str, default='qa_test.txt')
exp_arg = add_argument_group('Experiment')
exp_path = work_dir + '/exp_1_hop/'
exp_arg.add_argument('--exp_dir', type=str, default=exp_path)
log_arg = add_argument_group('Log')
log_arg.add_argument('--log_dir', type=str, default='logs')
log_arg.add_argument('--log_interval', type=int, default=1000)
log_arg.add_argument('--num_log_samples', type=int, default=3)
log_arg.add_argument(
'--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
io_arg = add_argument_group('IO')
io_arg.add_argument('--model_dir', type=str, default='model')
io_arg.add_argument('--snapshot_interval', type=int, default=1000)
io_arg.add_argument('--output_dir', type=str, default='output')
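# A small self-check sketch (not part of the original file): running this
# module directly prints a few parsed defaults and any unrecognized flags.
if __name__ == '__main__':
  demo_config, demo_unparsed = get_config()
  print('lstm_dim = %d, batch_size = %d, data_dir = %s' %
        (demo_config.lstm_dim, demo_config.batch_size, demo_config.data_dir))
  print('unparsed flags: %s' % demo_unparsed)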
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../../')))
import numpy as np
import tensorflow as tf
from config import get_config
from model_n2nmn.assembler import Assembler
from model_n2nmn.model import Model
from util.data_reader import DataReader
from util.data_reader import SampleBuilder
from util.misc import prepare_dirs_and_logger
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('snapshot_name', '00001000', 'snapshot file name')
def main(_):
config = prepare_dirs_and_logger(config_raw)
rng = np.random.RandomState(config.random_seed)
tf.set_random_seed(config.random_seed)
config.rng = rng
config.module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']
config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', '<eos>']
assembler = Assembler(config)
sample_builder = SampleBuilder(config)
config = sample_builder.config # update T_encoder according to data
data_test = sample_builder.data_all['test']
data_reader_test = DataReader(
config, data_test, assembler, shuffle=False, one_pass=True)
num_vocab_txt = len(sample_builder.dict_all)
num_vocab_nmn = len(assembler.module_names)
num_choices = len(sample_builder.dict_all)
# Network inputs
text_seq_batch = tf.placeholder(tf.int32, [None, None])
seq_len_batch = tf.placeholder(tf.int32, [None])
# The model
model = Model(
config,
sample_builder.kb,
text_seq_batch,
seq_len_batch,
num_vocab_txt=num_vocab_txt,
num_vocab_nmn=num_vocab_nmn,
EOS_idx=assembler.EOS_idx,
num_choices=num_choices,
decoder_sampling=False)
compiler = model.compiler
scores = model.scores
sess = tf.Session()
sess.run(tf.global_variables_initializer())
snapshot_file = os.path.join(config.model_dir, FLAGS.snapshot_name)
tf.logging.info('Snapshot file: %s' % snapshot_file)
snapshot_saver = tf.train.Saver()
snapshot_saver.restore(sess, snapshot_file)
# Evaluation metrics
  num_questions = len(data_test.Y)  # one sample per (question, answer) pair
tf.logging.info('# of test questions: %d' % num_questions)
answer_correct = 0
layout_correct = 0
layout_valid = 0
for batch in data_reader_test.batches():
# set up input and output tensors
h = sess.partial_run_setup(
fetches=[model.predicted_tokens, scores],
feeds=[text_seq_batch, seq_len_batch, compiler.loom_input_tensor])
# Part 1: Generate module layout
tokens = sess.partial_run(
h,
fetches=model.predicted_tokens,
feed_dict={
text_seq_batch: batch['input_seq_batch'],
seq_len_batch: batch['seq_len_batch']
})
# Compute accuracy of the predicted layout
gt_tokens = batch['gt_layout_batch']
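    # For example (hypothetical single-column batch, <eos> index 3):
    #   tokens    = [0, 1, 2, 3, 0]   predicted layout for one instance
    #   gt_tokens = [0, 1, 2, 3, 3]   ground-truth layout
    # Steps where gt_tokens is <eos> are ignored, so this column still counts
    # as correct: it matches the ground truth on every non-<eos> step.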
layout_correct += np.sum(
np.all(
np.logical_or(tokens == gt_tokens, gt_tokens == assembler.EOS_idx),
axis=0))
# Assemble the layout tokens into network structure
expr_list, expr_validity_array = assembler.assemble(tokens)
layout_valid += np.sum(expr_validity_array)
labels = batch['ans_label_batch']
# Build TensorFlow Fold input for NMN
expr_feed = compiler.build_feed_dict(expr_list)
# Part 2: Run NMN and learning steps
scores_val = sess.partial_run(h, scores, feed_dict=expr_feed)
# Compute accuracy
predictions = np.argmax(scores_val, axis=1)
answer_correct += np.sum(
np.logical_and(expr_validity_array, predictions == labels))
answer_accuracy = answer_correct * 1.0 / num_questions
layout_accuracy = layout_correct * 1.0 / num_questions
layout_validity = layout_valid * 1.0 / num_questions
tf.logging.info('test answer accuracy = %f, '
'test layout accuracy = %f, '
'test layout validity = %f' %
(answer_accuracy, layout_accuracy, layout_validity))
if __name__ == '__main__':
config_raw, unparsed = get_config()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../../')))
import numpy as np
import tensorflow as tf
from config import get_config
from model_n2nmn.assembler import Assembler
from model_n2nmn.model import Model
from util.data_reader import DataReader
from util.data_reader import SampleBuilder
from util.misc import prepare_dirs_and_logger
from util.misc import save_config
from util.misc import show_all_variables
def main(_):
config = prepare_dirs_and_logger(config_raw)
save_config(config)
rng = np.random.RandomState(config.random_seed)
tf.set_random_seed(config.random_seed)
config.rng = rng
config.module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']
config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', '<eos>']
assembler = Assembler(config)
sample_builder = SampleBuilder(config)
config = sample_builder.config # update T_encoder according to data
data_train = sample_builder.data_all['train']
data_reader_train = DataReader(
config, data_train, assembler, shuffle=True, one_pass=False)
num_vocab_txt = len(sample_builder.dict_all)
num_vocab_nmn = len(assembler.module_names)
num_choices = len(sample_builder.dict_all)
# Network inputs
text_seq_batch = tf.placeholder(tf.int32, [None, None])
seq_len_batch = tf.placeholder(tf.int32, [None])
ans_label_batch = tf.placeholder(tf.int32, [None])
use_gt_layout = tf.constant(True, dtype=tf.bool)
gt_layout_batch = tf.placeholder(tf.int32, [None, None])
# The model for training
model = Model(
config,
sample_builder.kb,
text_seq_batch,
seq_len_batch,
num_vocab_txt=num_vocab_txt,
num_vocab_nmn=num_vocab_nmn,
EOS_idx=assembler.EOS_idx,
num_choices=num_choices,
decoder_sampling=True,
use_gt_layout=use_gt_layout,
gt_layout_batch=gt_layout_batch)
compiler = model.compiler
scores = model.scores
log_seq_prob = model.log_seq_prob
# Loss function
softmax_loss_per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=scores, labels=ans_label_batch)
  # The final per-sample loss. In general this would combine the loss for
  # valid expressions with a penalty for invalid ones; here every layout
  # comes from the ground truth, so all expressions are valid.
  final_loss_per_sample = softmax_loss_per_sample  # All exprs are valid
avg_sample_loss = tf.reduce_mean(final_loss_per_sample)
seq_likelihood_loss = tf.reduce_mean(-log_seq_prob)
total_training_loss = seq_likelihood_loss + avg_sample_loss
total_loss = total_training_loss + config.weight_decay * model.l2_reg
# Train with Adam optimizer
solver = tf.train.AdamOptimizer()
gradients = solver.compute_gradients(total_loss)
  # Clip each gradient separately by its own L2 norm (per-variable clipping)
gradients = [(tf.clip_by_norm(g, config.max_grad_norm), v)
for g, v in gradients]
solver_op = solver.apply_gradients(gradients)
  # Training operation: fetching train_step runs solver_op via the control dependency
with tf.control_dependencies([solver_op]):
train_step = tf.constant(0)
# Write summary to TensorBoard
log_writer = tf.summary.FileWriter(config.log_dir, tf.get_default_graph())
loss_ph = tf.placeholder(tf.float32, [])
entropy_ph = tf.placeholder(tf.float32, [])
accuracy_ph = tf.placeholder(tf.float32, [])
summary_train = [
tf.summary.scalar('avg_sample_loss', loss_ph),
tf.summary.scalar('entropy', entropy_ph),
tf.summary.scalar('avg_accuracy', accuracy_ph)
]
log_step_train = tf.summary.merge(summary_train)
# Training
sess = tf.Session()
sess.run(tf.global_variables_initializer())
snapshot_saver = tf.train.Saver(max_to_keep=None) # keep all snapshots
show_all_variables()
avg_accuracy = 0
accuracy_decay = 0.99
for n_iter, batch in enumerate(data_reader_train.batches()):
if n_iter >= config.max_iter:
break
# set up input and output tensors
h = sess.partial_run_setup(
fetches=[
model.predicted_tokens, model.entropy_reg, scores, avg_sample_loss,
train_step
],
feeds=[
text_seq_batch, seq_len_batch, gt_layout_batch,
compiler.loom_input_tensor, ans_label_batch
])
# Part 1: Generate module layout
tokens, entropy_reg_val = sess.partial_run(
h,
fetches=(model.predicted_tokens, model.entropy_reg),
feed_dict={
text_seq_batch: batch['input_seq_batch'],
seq_len_batch: batch['seq_len_batch'],
gt_layout_batch: batch['gt_layout_batch']
})
# Assemble the layout tokens into network structure
expr_list, expr_validity_array = assembler.assemble(tokens)
# all exprs should be valid (since they are ground-truth)
assert np.all(expr_validity_array)
labels = batch['ans_label_batch']
# Build TensorFlow Fold input for NMN
expr_feed = compiler.build_feed_dict(expr_list)
expr_feed[ans_label_batch] = labels
# Part 2: Run NMN and learning steps
scores_val, avg_sample_loss_val, _ = sess.partial_run(
h, fetches=(scores, avg_sample_loss, train_step), feed_dict=expr_feed)
# Compute accuracy
predictions = np.argmax(scores_val, axis=1)
accuracy = np.mean(
np.logical_and(expr_validity_array, predictions == labels))
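    # avg_accuracy tracks an exponential moving average of the batch accuracy
    # with decay rate accuracy_decay (0.99).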
avg_accuracy += (1 - accuracy_decay) * (accuracy - avg_accuracy)
# Add to TensorBoard summary
if (n_iter + 1) % config.log_interval == 0:
tf.logging.info('iter = %d\n\t'
'loss = %f, accuracy (cur) = %f, '
'accuracy (avg) = %f, entropy = %f' %
(n_iter + 1, avg_sample_loss_val, accuracy, avg_accuracy,
-entropy_reg_val))
summary = sess.run(
fetches=log_step_train,
feed_dict={
loss_ph: avg_sample_loss_val,
entropy_ph: -entropy_reg_val,
accuracy_ph: avg_accuracy
})
log_writer.add_summary(summary, n_iter + 1)
# Save snapshot
if (n_iter + 1) % config.snapshot_interval == 0:
snapshot_file = os.path.join(config.model_dir, '%08d' % (n_iter + 1))
snapshot_saver.save(sess, snapshot_file, write_meta_graph=False)
tf.logging.info('Snapshot saved to %s' % snapshot_file)
tf.logging.info('Run finished.')
if __name__ == '__main__':
config_raw, unparsed = get_config()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
# the number of attention inputs each module consumes
_module_input_num = {
'_key_find': 0,
'_key_filter': 1,
'_val_desc': 1}
_module_output_type = {
'_key_find': 'att',
'_key_filter': 'att',
'_val_desc': 'ans'
}
INVALID_EXPR = 'INVALID_EXPR'
class Assembler:
def __init__(self, config):
# read the module list, and record the index of each module and <eos>
self.module_names = config.module_names
# find the index of <eos>
for n_s in range(len(self.module_names)):
if self.module_names[n_s] == '<eos>':
self.EOS_idx = n_s
break
# build a dictionary from module name to token index
self.name2idx_dict = {
name: n_s
for n_s, name in enumerate(self.module_names)
}
def module_list2tokens(self, module_list, max_len=None):
layout_tokens = [self.name2idx_dict[name] for name in module_list]
if max_len is not None:
if len(module_list) >= max_len:
raise ValueError('Not enough time steps to add <eos>')
layout_tokens += [self.EOS_idx] * (max_len - len(module_list))
return layout_tokens
def _layout_tokens2str(self, layout_tokens):
return ' '.join([self.module_names[idx] for idx in layout_tokens])
def _invalid_expr(self, layout_tokens, error_str):
return {
'module': INVALID_EXPR,
'expr_str': self._layout_tokens2str(layout_tokens),
'error': error_str
}
def _assemble_layout_tokens(self, layout_tokens, batch_idx):
    # Every module takes a time_idx (an index into the decoder LSTM hidden
    # states, even if the module does not use it) and a fixed number of
    # attention inputs. The output type is either attention ('att') or
    # answer ('ans').
    #
    # The final assembled expression for each instance is as follows
    # (every valid expr also carries a 'batch_idx' field):
    #   expr_type :=
    #     {'module': '_key_find', 'output_type': 'att', 'time_idx': idx}
    #   | {'module': '_key_filter', 'output_type': 'att', 'time_idx': idx,
    #      'input_0': <expr_type>}
    #   | {'module': '_val_desc', 'output_type': 'ans', 'time_idx': idx,
    #      'input_0': <expr_type>}
    #   | {'module': INVALID_EXPR, 'expr_str': '...', 'error': '...'}
    #     (for invalid expressions)
    #
    # A valid layout must contain <eos>. Assembly fails if it doesn't.
if not np.any(layout_tokens == self.EOS_idx):
return self._invalid_expr(layout_tokens, 'cannot find <eos>')
# Decoding Reverse Polish Notation with a stack
decoding_stack = []
for t in range(len(layout_tokens)):
# decode a module/operation
module_idx = layout_tokens[t]
if module_idx == self.EOS_idx:
break
module_name = self.module_names[module_idx]
expr = {
'module': module_name,
'output_type': _module_output_type[module_name],
'time_idx': t,
'batch_idx': batch_idx
}
input_num = _module_input_num[module_name]
      # Check whether there are enough inputs on the stack
      if len(decoding_stack) < input_num:
        # Invalid expression: not enough inputs.
return self._invalid_expr(layout_tokens,
'not enough input for ' + module_name)
# Get the input from stack
for n_input in range(input_num - 1, -1, -1):
stack_top = decoding_stack.pop()
if stack_top['output_type'] != 'att':
# Invalid expression. Input must be attention
return self._invalid_expr(layout_tokens,
'input incompatible for ' + module_name)
expr['input_%d' % n_input] = stack_top
decoding_stack.append(expr)
# After decoding the reverse polish expression, there should be exactly
# one expression in the stack
if len(decoding_stack) != 1:
return self._invalid_expr(
layout_tokens,
'final stack size not equal to 1 (%d remains)' % len(decoding_stack))
result = decoding_stack[0]
# The result type should be answer, not attention
if result['output_type'] != 'ans':
return self._invalid_expr(layout_tokens,
'result type must be ans, not att')
return result
def assemble(self, layout_tokens_batch):
# layout_tokens_batch is a numpy array with shape [max_dec_len, batch_size],
# containing module tokens and <eos>, in Reverse Polish Notation.
_, batch_size = layout_tokens_batch.shape
expr_list = [
self._assemble_layout_tokens(layout_tokens_batch[:, batch_i], batch_i)
for batch_i in range(batch_size)
]
expr_validity = np.array(
[expr['module'] != INVALID_EXPR for expr in expr_list], np.bool)
return expr_list, expr_validity
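# A minimal sketch (not part of the original file) of how Assembler turns a
# Reverse-Polish layout into a nested expression. The _DemoConfig stub is
# hypothetical and only carries the module_names field that Assembler reads.
if __name__ == '__main__':

  class _DemoConfig(object):
    module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']

  demo_assembler = Assembler(_DemoConfig())
  # One column per instance, shape [max_dec_len, batch_size] = [5, 1].
  demo_tokens = np.array([demo_assembler.module_list2tokens(
      ['_key_find', '_key_filter', '_val_desc'], max_len=5)]).T
  demo_exprs, demo_valid = demo_assembler.assemble(demo_tokens)
  # Expected nesting: _val_desc(input_0=_key_filter(input_0=_key_find))
  print(demo_exprs[0])
  print('valid: %s' % demo_valid)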
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import tensorflow_fold as td
from model_n2nmn import netgen_att
from model_n2nmn import assembler
from model_n2nmn.modules import Modules
class Model:
def __init__(self,
config,
kb,
text_seq_batch,
seq_length_batch,
num_vocab_txt,
num_vocab_nmn,
EOS_idx,
num_choices,
decoder_sampling,
use_gt_layout=None,
gt_layout_batch=None,
scope='neural_module_network',
reuse=None):
with tf.variable_scope(scope, reuse=reuse):
# Part 1: Seq2seq RNN to generate module layout tokens
embedding_mat = tf.get_variable(
'embedding_mat', [num_vocab_txt, config.embed_dim_txt],
initializer=tf.contrib.layers.xavier_initializer())
with tf.variable_scope('layout_generation'):
att_seq2seq = netgen_att.AttentionSeq2Seq(
config, text_seq_batch, seq_length_batch, num_vocab_txt,
num_vocab_nmn, EOS_idx, decoder_sampling, embedding_mat,
use_gt_layout, gt_layout_batch)
self.att_seq2seq = att_seq2seq
predicted_tokens = att_seq2seq.predicted_tokens
token_probs = att_seq2seq.token_probs
word_vecs = att_seq2seq.word_vecs
neg_entropy = att_seq2seq.neg_entropy
self.atts = att_seq2seq.atts
self.predicted_tokens = predicted_tokens
self.token_probs = token_probs
self.word_vecs = word_vecs
self.neg_entropy = neg_entropy
# log probability of each generated sequence
self.log_seq_prob = tf.reduce_sum(tf.log(token_probs), axis=0)
# Part 2: Neural Module Network
with tf.variable_scope('layout_execution'):
modules = Modules(config, kb, word_vecs, num_choices, embedding_mat)
self.modules = modules
# Recursion of modules
att_shape = [len(kb)]
# Forward declaration of module recursion
att_expr_decl = td.ForwardDeclaration(td.PyObjectType(),
td.TensorType(att_shape))
# _key_find
case_key_find = td.Record([('time_idx', td.Scalar(dtype='int32')),
('batch_idx', td.Scalar(dtype='int32'))])
case_key_find = case_key_find >> td.ScopedLayer(
modules.KeyFindModule, name_or_scope='KeyFindModule')
# _key_filter
case_key_filter = td.Record([('input_0', att_expr_decl()),
('time_idx', td.Scalar('int32')),
('batch_idx', td.Scalar('int32'))])
case_key_filter = case_key_filter >> td.ScopedLayer(
modules.KeyFilterModule, name_or_scope='KeyFilterModule')
recursion_cases = td.OneOf(
td.GetItem('module'),
{'_key_find': case_key_find,
'_key_filter': case_key_filter})
att_expr_decl.resolve_to(recursion_cases)
        # _val_desc: output scores over the answer choices (for valid expressions)
predicted_scores = td.Record([('input_0', recursion_cases),
('time_idx', td.Scalar('int32')),
('batch_idx', td.Scalar('int32'))])
predicted_scores = predicted_scores >> td.ScopedLayer(
modules.ValDescribeModule, name_or_scope='ValDescribeModule')
# For invalid expressions, define a dummy answer
# so that all answers have the same form
INVALID = assembler.INVALID_EXPR
dummy_scores = td.Void() >> td.FromTensor(
np.zeros(num_choices, np.float32))
output_scores = td.OneOf(
td.GetItem('module'),
{'_val_desc': predicted_scores,
INVALID: dummy_scores})
# compile and get the output scores
self.compiler = td.Compiler.create(output_scores)
self.scores = self.compiler.output_tensors[0]
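        # The compiler consumes the nested expression dicts produced by
        # Assembler.assemble (the output_type fields are ignored by the Fold
        # blocks), e.g. for one hypothetical instance:
        #   {'module': '_val_desc', 'time_idx': 2, 'batch_idx': 0,
        #    'input_0': {'module': '_key_filter', 'time_idx': 1, 'batch_idx': 0,
        #                'input_0': {'module': '_key_find', 'time_idx': 0,
        #                            'batch_idx': 0}}}
        # It is fed via compiler.build_feed_dict(expr_list) in the training and
        # evaluation scripts.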
# Regularization: Entropy + L2
self.entropy_reg = tf.reduce_mean(neg_entropy)
module_weights = [
v for v in tf.trainable_variables()
if (scope in v.op.name and v.op.name.endswith('weights'))
]
self.l2_reg = tf.add_n([tf.nn.l2_loss(v) for v in module_weights])
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
class Modules:
def __init__(self, config, kb, word_vecs, num_choices, embedding_mat):
self.config = config
self.embedding_mat = embedding_mat
# kb has shape [N_kb, 3]
self.kb = kb
self.embed_keys_e, self.embed_keys_r, self.embed_vals_e = self.embed_kb()
# word_vecs has shape [T_decoder, N, D_txt]
self.word_vecs = word_vecs
self.num_choices = num_choices
def embed_kb(self):
keys_e, keys_r, vals_e = [], [], []
for idx_sub, idx_rel, idx_obj in self.kb:
keys_e.append(idx_sub)
keys_r.append(idx_rel)
vals_e.append(idx_obj)
embed_keys_e = tf.nn.embedding_lookup(self.embedding_mat, keys_e)
embed_keys_r = tf.nn.embedding_lookup(self.embedding_mat, keys_r)
embed_vals_e = tf.nn.embedding_lookup(self.embedding_mat, vals_e)
return embed_keys_e, embed_keys_r, embed_vals_e
def _slice_word_vecs(self, time_idx, batch_idx):
# this callable will be wrapped into a td.Function
# In TF Fold, batch_idx and time_idx are both [N_batch, 1] tensors
    # time is the first (highest) dimension of word_vecs
joint_index = tf.stack([time_idx, batch_idx], axis=1)
return tf.gather_nd(self.word_vecs, joint_index)
# All the layers are wrapped with td.ScopedLayer
def KeyFindModule(self,
time_idx,
batch_idx,
scope='KeyFindModule',
reuse=None):
# In TF Fold, batch_idx and time_idx are both [N_batch, 1] tensors
text_param = self._slice_word_vecs(time_idx, batch_idx)
# Mapping: embed_keys_e x text_param -> att
# Input:
# embed_keys_e: [N_kb, D_txt]
# text_param: [N, D_txt]
# Output:
# att: [N, N_kb]
#
# Implementation:
    #   1. Inner product between text_param and each KB key entity embedding
    #      (a single matmul against embed_keys_e)
    #   2. L2-normalize each row of the resulting attention map
with tf.variable_scope(scope, reuse=reuse):
m = tf.matmul(text_param, self.embed_keys_e, transpose_b=True)
att = tf.nn.l2_normalize(m, dim=1)
return att
def KeyFilterModule(self,
input_0,
time_idx,
batch_idx,
scope='KeyFilterModule',
reuse=None):
att_0 = input_0
text_param = self._slice_word_vecs(time_idx, batch_idx)
# Mapping: and(embed_keys_r x text_param, att) -> att
# Input:
# embed_keys_r: [N_kb, D_txt]
# text_param: [N, D_txt]
# att_0: [N, N_kb]
# Output:
# att: [N, N_kb]
#
# Implementation:
    #   1. Inner product between text_param and each KB relation embedding
    #      (a single matmul against embed_keys_r)
    #   2. L2-normalize each row of the resulting attention map
    #   3. Take the elementwise minimum with the input attention att_0
with tf.variable_scope(scope, reuse=reuse):
m = tf.matmul(text_param, self.embed_keys_r, transpose_b=True)
att_1 = tf.nn.l2_normalize(m, dim=1)
att = tf.minimum(att_0, att_1)
return att
def ValDescribeModule(self,
input_0,
time_idx,
batch_idx,
scope='ValDescribeModule',
reuse=None):
att = input_0
# Mapping: att -> answer probs
# Input:
# embed_vals_e: [N_kb, D_txt]
# att: [N, N_kb]
# embedding_mat: [self.num_choices, D_txt]
# Output:
# answer_scores: [N, self.num_choices]
#
# Implementation:
    #   1. Attention-weighted sum over the KB value embeddings
    #   2. Dot-product scores between the weighted sum and each
    #      L2-normalized candidate answer embedding
with tf.variable_scope(scope, reuse=reuse):
# weighted_sum has shape [N, D_txt]
weighted_sum = tf.matmul(att, self.embed_vals_e)
# scores has shape [N, self.num_choices]
scores = tf.matmul(
weighted_sum,
tf.nn.l2_normalize(self.embedding_mat, dim=1),
transpose_b=True)
return scores
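# A minimal sketch (not part of the original file) that wires the three modules
# together on a toy KB under TF 1.x. All names and sizes here are hypothetical.
if __name__ == '__main__':

  class _DemoConfig(object):
    pass

  # Two (subject, relation, object) triples; ids index into the embedding matrix.
  demo_kb = [[0, 1, 2], [3, 1, 4]]
  demo_embedding = tf.get_variable('demo_embedding_mat', [5, 8])
  # word_vecs has shape [T_decoder, N, D_txt] = [3, 1, 8]
  demo_word_vecs = tf.random_normal([3, 1, 8])
  demo_modules = Modules(_DemoConfig(), demo_kb, demo_word_vecs,
                         num_choices=5, embedding_mat=demo_embedding)
  batch_idx = tf.constant([0])
  att = demo_modules.KeyFindModule(tf.constant([0]), batch_idx)
  att = demo_modules.KeyFilterModule(att, tf.constant([1]), batch_idx)
  scores = demo_modules.ValDescribeModule(att, tf.constant([2]), batch_idx)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(scores).shape)  # expected: (1, 5)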
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from util.nn import fc_layer as fc
def _get_lstm_cell(num_layers, lstm_dim):
cell_list = [
tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
for _ in range(num_layers)
]
cell = tf.contrib.rnn.MultiRNNCell(cell_list, state_is_tuple=True)
return cell
class AttentionSeq2Seq:
def __init__(self,
config,
text_seq_batch,
seq_length_batch,
num_vocab_txt,
num_vocab_nmn,
EOS_token,
decoder_sampling,
embedding_mat,
use_gt_layout=None,
gt_layout_batch=None,
scope='encoder_decoder',
reuse=None):
self.T_decoder = config.T_decoder
self.encoder_num_vocab = num_vocab_txt
self.encoder_embed_dim = config.embed_dim_txt
self.decoder_num_vocab = num_vocab_nmn
self.decoder_embed_dim = config.embed_dim_nmn
self.lstm_dim = config.lstm_dim
self.num_layers = config.num_layers
self.EOS_token = EOS_token
self.decoder_sampling = decoder_sampling
self.embedding_mat = embedding_mat
with tf.variable_scope(scope, reuse=reuse):
self._build_encoder(text_seq_batch, seq_length_batch)
self._build_decoder(use_gt_layout, gt_layout_batch)
def _build_encoder(self,
text_seq_batch,
seq_length_batch,
scope='encoder',
reuse=None):
lstm_dim = self.lstm_dim
num_layers = self.num_layers
with tf.variable_scope(scope, reuse=reuse):
T = tf.shape(text_seq_batch)[0]
N = tf.shape(text_seq_batch)[1]
self.T_encoder = T
self.N = N
# text_seq has shape [T, N] and embedded_seq has shape [T, N, D]
embedded_seq = tf.nn.embedding_lookup(self.embedding_mat, text_seq_batch)
self.embedded_input_seq = embedded_seq
# The RNN
cell = _get_lstm_cell(num_layers, lstm_dim)
# encoder_outputs has shape [T, N, lstm_dim]
encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
cell,
embedded_seq,
seq_length_batch,
dtype=tf.float32,
time_major=True,
scope='lstm')
self.encoder_outputs = encoder_outputs
self.encoder_states = encoder_states
      # transform the encoder outputs for attention alignment
      # encoder_h_transformed has shape [T, N, lstm_dim] after the reshape
encoder_h_transformed = fc(
'encoder_h_transform',
tf.reshape(encoder_outputs, [-1, lstm_dim]),
output_dim=lstm_dim)
encoder_h_transformed = tf.reshape(encoder_h_transformed,
[T, N, lstm_dim])
self.encoder_h_transformed = encoder_h_transformed
# seq_not_finished is a shape [T, N, 1] tensor,
# where seq_not_finished[t, n]
# is 1 iff sequence n is not finished at time t, and 0 otherwise
seq_not_finished = tf.less(
tf.range(T)[:, tf.newaxis, tf.newaxis],
seq_length_batch[:, tf.newaxis])
seq_not_finished = tf.cast(seq_not_finished, tf.float32)
self.seq_not_finished = seq_not_finished
def _build_decoder(self,
use_gt_layout,
gt_layout_batch,
scope='decoder',
reuse=None):
    # Unlike a plain seq2seq decoder, this decoder takes an extra input (the
    # attention context over the encoder states) when computing each step.
# T_max is the maximum length of decoded sequence (including <eos>)
#
# This function is for decoding only. It performs greedy search or sampling.
# the first input is <go> (its embedding vector) and the subsequent inputs
# are the outputs from previous time step
# num_vocab does not include <go>
#
# use_gt_layout is None or a bool tensor, and gt_layout_batch is a tensor
# with shape [T_max, N].
# If use_gt_layout is not None, then when use_gt_layout is true, predict
# exactly the tokens in gt_layout_batch, regardless of actual probability.
# Otherwise, if sampling is True, sample from the token probability
# If sampling is False, do greedy decoding (beam size 1)
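    # For example, when use_gt_layout evaluates to True the decoder emits
    # exactly gt_layout_batch (teacher forcing on the layout), while
    # token_probs still hold the model's own probabilities for those tokens.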
N = self.N
encoder_states = self.encoder_states
T_max = self.T_decoder
lstm_dim = self.lstm_dim
num_layers = self.num_layers
EOS_token = self.EOS_token
sampling = self.decoder_sampling
with tf.variable_scope(scope, reuse=reuse):
embedding_mat = tf.get_variable(
'embedding_mat', [self.decoder_num_vocab, self.decoder_embed_dim])
# we use a separate embedding for <go>, as it is only used in the
# beginning of the sequence
go_embedding = tf.get_variable('go_embedding',
[1, self.decoder_embed_dim])
with tf.variable_scope('att_prediction'):
v = tf.get_variable('v', [lstm_dim])
W_a = tf.get_variable(
'weights', [lstm_dim, lstm_dim],
initializer=tf.contrib.layers.xavier_initializer())
b_a = tf.get_variable(
'biases', lstm_dim, initializer=tf.constant_initializer(0.))
# The parameters to predict the next token
with tf.variable_scope('token_prediction'):
W_y = tf.get_variable(
'weights', [lstm_dim * 2, self.decoder_num_vocab],
initializer=tf.contrib.layers.xavier_initializer())
b_y = tf.get_variable(
'biases',
self.decoder_num_vocab,
initializer=tf.constant_initializer(0.))
# Attentional decoding
# Loop function is called at time t BEFORE the cell execution at time t,
# and its next_input is used as the input at time t (not t+1)
# c.f. https://www.tensorflow.org/api_docs/python/tf/nn/raw_rnn
mask_range = tf.reshape(
tf.range(self.decoder_num_vocab, dtype=tf.int32), [1, -1])
all_eos_pred = EOS_token * tf.ones([N], tf.int32)
all_one_prob = tf.ones([N], tf.float32)
all_zero_entropy = tf.zeros([N], tf.float32)
if use_gt_layout is not None:
gt_layout_mult = tf.cast(use_gt_layout, tf.int32)
pred_layout_mult = 1 - gt_layout_mult
def loop_fn(time, cell_output, cell_state, loop_state):
if cell_output is None: # time == 0
next_cell_state = encoder_states
next_input = tf.tile(go_embedding, [N, 1])
else: # time > 0
next_cell_state = cell_state
# compute the attention map over the input sequence
          # att_raw has shape [T, N, 1]
att_raw = tf.reduce_sum(
tf.tanh(
tf.nn.xw_plus_b(cell_output, W_a, b_a) +
self.encoder_h_transformed) * v,
axis=2,
keep_dims=True)
          # softmax over the time dimension (T), masked to each sequence's real length
# att has shape [T, N, 1]
att = tf.nn.softmax(att_raw, dim=0) * self.seq_not_finished
att = att / tf.reduce_sum(att, axis=0, keep_dims=True)
          # d2 has shape [N, lstm_dim]: the attention-weighted encoder context
d2 = tf.reduce_sum(att * self.encoder_outputs, axis=0)
# token_scores has shape [N, num_vocab]
token_scores = tf.nn.xw_plus_b(
tf.concat([cell_output, d2], axis=1), W_y, b_y)
# predict the next token (behavior depending on parameters)
if sampling:
# predicted_token has shape [N]
logits = token_scores
predicted_token = tf.cast(
tf.reshape(tf.multinomial(token_scores, 1), [-1]), tf.int32)
else:
# predicted_token has shape [N]
predicted_token = tf.cast(tf.argmax(token_scores, 1), tf.int32)
if use_gt_layout is not None:
predicted_token = (gt_layout_batch[time - 1] * gt_layout_mult +
predicted_token * pred_layout_mult)
          # token_prob has shape [N]: the probability of the predicted token.
          # It is not needed for predicting the next token, but it is needed
          # in the output (for policy-gradient training).
          # mask has shape [N, num_vocab]
mask = tf.equal(mask_range, tf.reshape(predicted_token, [-1, 1]))
all_token_probs = tf.nn.softmax(token_scores)
token_prob = tf.reduce_sum(
all_token_probs * tf.cast(mask, tf.float32), axis=1)
neg_entropy = tf.reduce_sum(
all_token_probs * tf.log(all_token_probs), axis=1)
# is_eos_predicted is a [N] bool tensor, indicating whether
# <eos> has already been predicted previously in each sequence
is_eos_predicted = loop_state[2]
predicted_token_old = predicted_token
# if <eos> has already been predicted, now predict <eos> with
# prob 1
predicted_token = tf.where(is_eos_predicted, all_eos_pred,
predicted_token)
token_prob = tf.where(is_eos_predicted, all_one_prob, token_prob)
neg_entropy = tf.where(is_eos_predicted, all_zero_entropy,
neg_entropy)
is_eos_predicted = tf.logical_or(is_eos_predicted,
tf.equal(predicted_token_old,
EOS_token))
          # the prediction comes from the cell output of the previous
          # timestep (t-1); feed its embedding as the input at timestep t
next_input = tf.nn.embedding_lookup(embedding_mat, predicted_token)
elements_finished = tf.greater_equal(time, T_max)
# loop_state is a 5-tuple, representing
# 1) the predicted_tokens
# 2) the prob of predicted_tokens
# 3) whether <eos> has already been predicted
# 4) the negative entropy of policy (accumulated across timesteps)
# 5) the attention
if loop_state is None: # time == 0
# Write the predicted token into the output
predicted_token_array = tf.TensorArray(
dtype=tf.int32, size=T_max, infer_shape=False)
token_prob_array = tf.TensorArray(
dtype=tf.float32, size=T_max, infer_shape=False)
att_array = tf.TensorArray(
dtype=tf.float32, size=T_max, infer_shape=False)
next_loop_state = (predicted_token_array, token_prob_array, tf.zeros(
[N], dtype=tf.bool), tf.zeros([N], dtype=tf.float32), att_array)
else: # time > 0
t_write = time - 1
next_loop_state = (
loop_state[0].write(t_write, predicted_token),
loop_state[1].write(t_write, token_prob),
is_eos_predicted,
loop_state[3] + neg_entropy,
loop_state[4].write(t_write, att))
return (elements_finished, next_input, next_cell_state, cell_output,
next_loop_state)
# The RNN
cell = _get_lstm_cell(num_layers, lstm_dim)
_, _, decodes_ta = tf.nn.raw_rnn(cell, loop_fn, scope='lstm')
predicted_tokens = decodes_ta[0].stack()
token_probs = decodes_ta[1].stack()
neg_entropy = decodes_ta[3]
# atts has shape [T_decoder, T_encoder, N, 1]
atts = decodes_ta[4].stack()
self.atts = atts
      # word_vecs has shape [T_decoder, N, D]
word_vecs = tf.reduce_sum(atts * self.embedded_input_seq, axis=1)
predicted_tokens.set_shape([None, None])
token_probs.set_shape([None, None])
neg_entropy.set_shape([None])
word_vecs.set_shape([None, None, self.encoder_embed_dim])
self.predicted_tokens = predicted_tokens
self.token_probs = token_probs
self.neg_entropy = neg_entropy
self.word_vecs = word_vecs
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import namedtuple
from Queue import Queue
import re
import threading
import numpy as np
import tensorflow as tf
Data = namedtuple('Data', ['X', 'Y', 'MultiYs', 'qid'])
class SampleBuilder:
def __init__(self, config):
self.config = config
self.kb_raw = self.read_kb()
self.data_raw = self.read_raw_data()
# dictionary of entities, normal words, and relations
self.dict_all = self.gen_dict()
self.reverse_dict_all = dict(
zip(self.dict_all.values(), self.dict_all.keys()))
tf.logging.info('size of dict: %d' % len(self.dict_all))
self.kb = self.build_kb()
self.data_all = self.build_samples()
def read_kb(self):
kb_raw = []
    for line in open(self.config.KB_file):
sub, rel, obj = line.strip().split('|')
kb_raw.append((sub, rel, obj))
tf.logging.info('# of KB records: %d' % len(kb_raw))
return kb_raw
def read_raw_data(self):
data = dict()
for name in self.config.data_files:
raw = []
tf.logging.info(
'Reading data file {}'.format(self.config.data_files[name]))
      for line in open(self.config.data_files[name]):
question, answers = line.strip().split('\t')
question = question.replace('],', ']') # ignore ',' in the template
raw.append((question, answers))
data[name] = raw
return data
def build_kb(self):
tf.logging.info('Indexing KB...')
kb = []
for sub, rel, obj in self.kb_raw:
kb.append([self.dict_all[sub], self.dict_all[rel], self.dict_all[obj]])
return kb
def gen_dict(self):
s = set()
for sub, rel, obj in self.kb_raw:
s.add(sub)
s.add(rel)
s.add(obj)
for name in self.data_raw:
for question, answers in self.data_raw[name]:
normal = re.split('\[[^\]]+\]', question)
for phrase in normal:
for word in phrase.split():
s.add(word)
s = list(s)
d = {s[idx]: idx for idx in range(len(s))}
return d
def build_samples(self):
def map_entity_idx(text):
entities = re.findall('\[[^\]]+\]', text)
for entity in entities:
entity = entity[1:-1]
index = self.dict_all[entity]
text = text.replace('[%s]' % entity, '@%d' % index)
return text
data_all = dict()
for name in self.data_raw:
X, Y, MultiYs, qid = [], [], [], []
for i, (question, answers) in enumerate(self.data_raw[name]):
qdata, labels = [], []
question = map_entity_idx(question)
for word in question.split():
if word[0] == '@':
qdata.append(int(word[1:]))
else:
qdata.append(self.dict_all[word])
for answer in answers.split('|'):
labels.append(self.dict_all[answer])
if len(qdata) > self.config.T_encoder:
self.config.T_encoder = len(qdata)
for label in labels:
X.append(qdata)
Y.append(label)
MultiYs.append(set(labels))
qid.append(i)
data_all[name] = Data(X=X, Y=Y, MultiYs=MultiYs, qid=qid)
return data_all
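# SampleBuilder.build_samples example (hypothetical MetaQA-style line in a
# qa_*.txt file):
#   "what films did [ginger rogers] write\tTop Hat|Kitty Foyle"
# The bracketed entity is mapped to '@<dict index>', the question is tokenized
# into one shared index sequence, and one (X, Y) sample is emitted per answer,
# each carrying the full answer set in MultiYs and the same qid.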
def _run_prefetch(prefetch_queue, batch_loader, data, shuffle, one_pass,
config):
assert len(data.X) == len(data.Y) == len(data.MultiYs) == len(data.qid)
num_samples = len(data.X)
batch_size = config.batch_size
n_sample = 0
fetch_order = config.rng.permutation(num_samples)
while True:
sample_ids = fetch_order[n_sample:n_sample + batch_size]
batch = batch_loader.load_one_batch(sample_ids)
prefetch_queue.put(batch, block=True)
n_sample += len(sample_ids)
if n_sample >= num_samples:
if one_pass:
prefetch_queue.put(None, block=True)
n_sample = 0
if shuffle:
fetch_order = config.rng.permutation(num_samples)
class DataReader:
def __init__(self,
config,
data,
assembler,
shuffle=True,
one_pass=False,
prefetch_num=10):
self.config = config
self.data = data
self.assembler = assembler
self.batch_loader = BatchLoader(self.config,
self.data, self.assembler)
self.shuffle = shuffle
self.one_pass = one_pass
self.prefetch_queue = Queue(maxsize=prefetch_num)
self.prefetch_thread = threading.Thread(target=_run_prefetch,
args=(self.prefetch_queue,
self.batch_loader, self.data,
self.shuffle, self.one_pass,
self.config))
self.prefetch_thread.daemon = True
self.prefetch_thread.start()
def batches(self):
while True:
if self.prefetch_queue.empty():
tf.logging.warning('Waiting for data loading (IO is slow)...')
batch = self.prefetch_queue.get(block=True)
if batch is None:
assert self.one_pass
tf.logging.info('One pass finished!')
        return  # ends the generator (PEP 479-safe equivalent of StopIteration)
yield batch
class BatchLoader:
def __init__(self, config,
data, assembler):
self.config = config
self.data = data
self.assembler = assembler
self.T_encoder = config.T_encoder
self.T_decoder = config.T_decoder
tf.logging.info('T_encoder: %d' % self.T_encoder)
tf.logging.info('T_decoder: %d' % self.T_decoder)
tf.logging.info('batch size: %d' % self.config.batch_size)
self.gt_layout_tokens = config.gt_layout_tokens
def load_one_batch(self, sample_ids):
actual_batch_size = len(sample_ids)
input_seq_batch = np.zeros((self.T_encoder, actual_batch_size), np.int32)
seq_len_batch = np.zeros(actual_batch_size, np.int32)
ans_label_batch = np.zeros(actual_batch_size, np.int32)
ans_set_labels_list = [None] * actual_batch_size
question_id_list = [None] * actual_batch_size
gt_layout_batch = np.zeros((self.T_decoder, actual_batch_size), np.int32)
for batch_i in range(actual_batch_size):
idx = sample_ids[batch_i]
seq_len = len(self.data.X[idx])
seq_len_batch[batch_i] = seq_len
input_seq_batch[:seq_len, batch_i] = self.data.X[idx]
ans_label_batch[batch_i] = self.data.Y[idx]
ans_set_labels_list[batch_i] = self.data.MultiYs[idx]
question_id_list[batch_i] = self.data.qid[idx]
gt_layout_batch[:, batch_i] = self.assembler.module_list2tokens(
self.gt_layout_tokens, self.T_decoder)
batch = dict(input_seq_batch=input_seq_batch,
seq_len_batch=seq_len_batch,
ans_label_batch=ans_label_batch,
gt_layout_batch=gt_layout_batch,
ans_set_labels_list=ans_set_labels_list,
question_id_list=question_id_list)
return batch
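# For example, with T_encoder = 8, T_decoder = 5 and two sampled questions of
# length 6 and 4, load_one_batch returns input_seq_batch as an [8, 2] int32
# array whose columns are the question token ids padded with trailing zeros,
# seq_len_batch as [6, 4], and gt_layout_batch as a [5, 2] array in which every
# column is the fixed layout [_key_find, _key_filter, _val_desc, <eos>, <eos>].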
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from datetime import datetime
import json
import logging
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
def prepare_dirs_and_logger(config):
formatter = logging.Formatter('%(asctime)s:%(levelname)s::%(message)s')
logger = logging.getLogger('tensorflow')
for hdlr in logger.handlers:
logger.removeHandler(hdlr)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(tf.logging.INFO)
config.log_dir = os.path.join(config.exp_dir, config.log_dir,
config.train_tag)
config.model_dir = os.path.join(config.exp_dir, config.model_dir,
config.train_tag)
config.output_dir = os.path.join(config.exp_dir, config.output_dir,
config.train_tag)
for path in [
config.log_dir, config.model_dir, config.output_dir
]:
if not os.path.exists(path):
os.makedirs(path)
config.data_files = {
'train': os.path.join(config.data_dir, config.train_data_file),
'dev': os.path.join(config.data_dir, config.dev_data_file),
'test': os.path.join(config.data_dir, config.test_data_file)
}
return config
def get_time():
return datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
def show_all_variables():
model_vars = tf.trainable_variables()
slim.model_analyzer.analyze_vars(model_vars, print_info=True)
def save_config(config):
param_path = os.path.join(config.model_dir, 'params.json')
tf.logging.info('log dir: %s' % config.log_dir)
tf.logging.info('model dir: %s' % config.model_dir)
tf.logging.info('param path: %s' % param_path)
tf.logging.info('output dir: %s' % config.output_dir)
with open(param_path, 'w') as f:
f.write(json.dumps(config.__dict__, indent=4, sort_keys=True))
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
def fc_layer(name,
bottom,
output_dim,
bias_term=True,
weights_initializer=None,
biases_initializer=None,
reuse=None):
# flatten bottom input
shape = bottom.get_shape().as_list()
input_dim = 1
for d in shape[1:]:
input_dim *= d
flat_bottom = tf.reshape(bottom, [-1, input_dim])
# weights and biases variables
with tf.variable_scope(name, reuse=reuse):
# initialize the variables
if weights_initializer is None:
weights_initializer = tf.contrib.layers.xavier_initializer()
if bias_term and biases_initializer is None:
biases_initializer = tf.constant_initializer(0.)
# weights has shape [input_dim, output_dim]
weights = tf.get_variable(
'weights', [input_dim, output_dim], initializer=weights_initializer)
if bias_term:
biases = tf.get_variable(
'biases', output_dim, initializer=biases_initializer)
if not reuse:
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
tf.nn.l2_loss(weights))
if bias_term:
fc = tf.nn.xw_plus_b(flat_bottom, weights, biases)
else:
fc = tf.matmul(flat_bottom, weights)
return fc
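# A minimal usage sketch (not part of the original file): project hypothetical
# 128-d features down to 64 dimensions with the layer defined above.
if __name__ == '__main__':
  demo_input = tf.placeholder(tf.float32, [None, 128])
  demo_output = fc_layer('demo_fc', demo_input, output_dim=64)
  print(demo_output.get_shape().as_list())  # [None, 64]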