Commit 480f0ebf authored by Yuyu Zhang

commit code

parent 38f008a9
@@ -33,9 +33,11 @@ The code in this repository is based on the original
 ## Data
-1. Download the [MetaQA dataset](https://goo.gl/f3AmcY). Read the documents
-   there for dataset details.
-2. Put the MetaQA folder in the root directory of this repository.
+1. Download the [MetaQA dataset](https://goo.gl/f3AmcY). Click the button
+   `MetaQA` and then click `Download` in the drop-down list. Extract the zip
+   file after the download completes. Read the documents there for dataset
+   details.
+2. Move the `MetaQA` folder to the root directory of this repository.
 ## How to use this code
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
def str2bool(v):
return v.lower() in ('true', '1')
def add_argument_group(name):
arg = parser.add_argument_group(name)
arg_lists.append(arg)
return arg
def get_config():
config, unparsed = parser.parse_known_args()
return config, unparsed
arg_lists = []
parser = argparse.ArgumentParser()
work_dir = os.path.abspath(os.path.join(__file__, '../../'))
net_arg = add_argument_group('Network')
net_arg.add_argument('--lstm_dim', type=int, default=128)
net_arg.add_argument('--num_layers', type=int, default=1)
net_arg.add_argument('--embed_dim_txt', type=int, default=128)
net_arg.add_argument('--embed_dim_nmn', type=int, default=128)
net_arg.add_argument(
'--T_encoder', type=int, default=0) # will be updated when reading data
net_arg.add_argument('--T_decoder', type=int, default=5)
train_arg = add_argument_group('Training')
train_arg.add_argument('--train_tag', type=str, default='n2nmn')
train_arg.add_argument('--batch_size', type=int, default=128)
train_arg.add_argument('--max_iter', type=int, default=1000000)
train_arg.add_argument('--weight_decay', type=float, default=1e-5)
train_arg.add_argument('--baseline_decay', type=float, default=0.99)
train_arg.add_argument('--max_grad_norm', type=float, default=10)
train_arg.add_argument('--random_seed', type=int, default=123)
data_arg = add_argument_group('Data')
data_path = work_dir + '/MetaQA/'
data_arg.add_argument('--KB_file', type=str, default=data_path + 'kb.txt')
data_arg.add_argument(
'--data_dir', type=str, default=data_path + '1-hop/vanilla/')
data_arg.add_argument('--train_data_file', type=str, default='qa_train.txt')
data_arg.add_argument('--dev_data_file', type=str, default='qa_dev.txt')
data_arg.add_argument('--test_data_file', type=str, default='qa_test.txt')
exp_arg = add_argument_group('Experiment')
exp_path = work_dir + '/exp_1_hop/'
exp_arg.add_argument('--exp_dir', type=str, default=exp_path)
log_arg = add_argument_group('Log')
log_arg.add_argument('--log_dir', type=str, default='logs')
log_arg.add_argument('--log_interval', type=int, default=1000)
log_arg.add_argument('--num_log_samples', type=int, default=3)
log_arg.add_argument(
'--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
io_arg = add_argument_group('IO')
io_arg.add_argument('--model_dir', type=str, default='model')
io_arg.add_argument('--snapshot_interval', type=int, default=1000)
io_arg.add_argument('--output_dir', type=str, default='output')
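# A minimal usage sketch: the training and evaluation scripts below import
# get_config() from this module (`from config import get_config`), parse the
# known flags, and hand anything unparsed on to tf.app.run. Defaults can be
# overridden on the command line (script name illustrative), e.g.
#   python train.py --batch_size 64 --data_dir MetaQA/1-hop/vanilla/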
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../../')))
import numpy as np
import tensorflow as tf
from config import get_config
from model_n2nmn.assembler import Assembler
from model_n2nmn.model import Model
from util.data_reader import DataReader
from util.data_reader import SampleBuilder
from util.misc import prepare_dirs_and_logger
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('snapshot_name', '00001000', 'snapshot file name')
def main(_):
config = prepare_dirs_and_logger(config_raw)
rng = np.random.RandomState(config.random_seed)
tf.set_random_seed(config.random_seed)
config.rng = rng
config.module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']
config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', '<eos>']
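  # Every question shares the same fixed ground-truth layout: the Reverse
  # Polish program _key_find -> _key_filter -> _val_desc, padded with <eos> up
  # to T_decoder. At test time it is only used to score layout accuracy; the
  # decoder predicts its own layout greedily (decoder_sampling=False below).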
assembler = Assembler(config)
sample_builder = SampleBuilder(config)
config = sample_builder.config # update T_encoder according to data
data_test = sample_builder.data_all['test']
data_reader_test = DataReader(
config, data_test, assembler, shuffle=False, one_pass=True)
num_vocab_txt = len(sample_builder.dict_all)
num_vocab_nmn = len(assembler.module_names)
num_choices = len(sample_builder.dict_all)
# Network inputs
text_seq_batch = tf.placeholder(tf.int32, [None, None])
seq_len_batch = tf.placeholder(tf.int32, [None])
# The model
model = Model(
config,
sample_builder.kb,
text_seq_batch,
seq_len_batch,
num_vocab_txt=num_vocab_txt,
num_vocab_nmn=num_vocab_nmn,
EOS_idx=assembler.EOS_idx,
num_choices=num_choices,
decoder_sampling=False)
compiler = model.compiler
scores = model.scores
sess = tf.Session()
sess.run(tf.global_variables_initializer())
snapshot_file = os.path.join(config.model_dir, FLAGS.snapshot_name)
tf.logging.info('Snapshot file: %s' % snapshot_file)
snapshot_saver = tf.train.Saver()
snapshot_saver.restore(sess, snapshot_file)
# Evaluation metrics
num_questions = len(data_test.Y)
tf.logging.info('# of test questions: %d' % num_questions)
answer_correct = 0
layout_correct = 0
layout_valid = 0
for batch in data_reader_test.batches():
# set up input and output tensors
h = sess.partial_run_setup(
fetches=[model.predicted_tokens, scores],
feeds=[text_seq_batch, seq_len_batch, compiler.loom_input_tensor])
# Part 1: Generate module layout
tokens = sess.partial_run(
h,
fetches=model.predicted_tokens,
feed_dict={
text_seq_batch: batch['input_seq_batch'],
seq_len_batch: batch['seq_len_batch']
})
# Compute accuracy of the predicted layout
gt_tokens = batch['gt_layout_batch']
layout_correct += np.sum(
np.all(
np.logical_or(tokens == gt_tokens, gt_tokens == assembler.EOS_idx),
axis=0))
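    # (A predicted layout counts as correct when every decoding step either
    # matches the ground-truth token or falls on ground-truth <eos> padding;
    # np.all runs over the time axis, so the sum counts whole sequences.)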
# Assemble the layout tokens into network structure
expr_list, expr_validity_array = assembler.assemble(tokens)
layout_valid += np.sum(expr_validity_array)
labels = batch['ans_label_batch']
# Build TensorFlow Fold input for NMN
expr_feed = compiler.build_feed_dict(expr_list)
# Part 2: Run NMN and learning steps
scores_val = sess.partial_run(h, scores, feed_dict=expr_feed)
# Compute accuracy
predictions = np.argmax(scores_val, axis=1)
answer_correct += np.sum(
np.logical_and(expr_validity_array, predictions == labels))
answer_accuracy = answer_correct * 1.0 / num_questions
layout_accuracy = layout_correct * 1.0 / num_questions
layout_validity = layout_valid * 1.0 / num_questions
tf.logging.info('test answer accuracy = %f, '
'test layout accuracy = %f, '
'test layout validity = %f' %
(answer_accuracy, layout_accuracy, layout_validity))
if __name__ == '__main__':
config_raw, unparsed = get_config()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
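# A usage sketch (the script file name is illustrative): evaluation restores
# the checkpoint selected by --snapshot_name from config.model_dir and reports
# answer accuracy, layout accuracy, and layout validity on the test split, e.g.
#   python eval.py --snapshot_name 00010000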
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../../')))
import numpy as np
import tensorflow as tf
from config import get_config
from model_n2nmn.assembler import Assembler
from model_n2nmn.model import Model
from util.data_reader import DataReader
from util.data_reader import SampleBuilder
from util.misc import prepare_dirs_and_logger
from util.misc import save_config
from util.misc import show_all_variables
def main(_):
config = prepare_dirs_and_logger(config_raw)
save_config(config)
rng = np.random.RandomState(config.random_seed)
tf.set_random_seed(config.random_seed)
config.rng = rng
config.module_names = ['_key_find', '_key_filter', '_val_desc', '<eos>']
config.gt_layout_tokens = ['_key_find', '_key_filter', '_val_desc', '<eos>']
assembler = Assembler(config)
sample_builder = SampleBuilder(config)
config = sample_builder.config # update T_encoder according to data
data_train = sample_builder.data_all['train']
data_reader_train = DataReader(
config, data_train, assembler, shuffle=True, one_pass=False)
num_vocab_txt = len(sample_builder.dict_all)
num_vocab_nmn = len(assembler.module_names)
num_choices = len(sample_builder.dict_all)
# Network inputs
text_seq_batch = tf.placeholder(tf.int32, [None, None])
seq_len_batch = tf.placeholder(tf.int32, [None])
ans_label_batch = tf.placeholder(tf.int32, [None])
use_gt_layout = tf.constant(True, dtype=tf.bool)
gt_layout_batch = tf.placeholder(tf.int32, [None, None])
# The model for training
model = Model(
config,
sample_builder.kb,
text_seq_batch,
seq_len_batch,
num_vocab_txt=num_vocab_txt,
num_vocab_nmn=num_vocab_nmn,
EOS_idx=assembler.EOS_idx,
num_choices=num_choices,
decoder_sampling=True,
use_gt_layout=use_gt_layout,
gt_layout_batch=gt_layout_batch)
compiler = model.compiler
scores = model.scores
log_seq_prob = model.log_seq_prob
# Loss function
softmax_loss_per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=scores, labels=ans_label_batch)
  # The final per-sample loss; since ground-truth layouts are used here, all
  # expressions are valid and no separate invalid-expression loss is needed
final_loss_per_sample = softmax_loss_per_sample # All exprs are valid
avg_sample_loss = tf.reduce_mean(final_loss_per_sample)
seq_likelihood_loss = tf.reduce_mean(-log_seq_prob)
total_training_loss = seq_likelihood_loss + avg_sample_loss
total_loss = total_training_loss + config.weight_decay * model.l2_reg
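  # The total loss thus combines (1) the answer cross-entropy, (2) the negative
  # log-likelihood of the generated layout sequence, and (3) L2 weight decay.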
# Train with Adam optimizer
solver = tf.train.AdamOptimizer()
gradients = solver.compute_gradients(total_loss)
# Clip gradient by L2 norm
gradients = [(tf.clip_by_norm(g, config.max_grad_norm), v)
for g, v in gradients]
solver_op = solver.apply_gradients(gradients)
# Training operation
with tf.control_dependencies([solver_op]):
train_step = tf.constant(0)
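  # (train_step is a dummy constant created under a control dependency on
  # solver_op, so fetching it in the partial_run below forces one Adam update.)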
# Write summary to TensorBoard
log_writer = tf.summary.FileWriter(config.log_dir, tf.get_default_graph())
loss_ph = tf.placeholder(tf.float32, [])
entropy_ph = tf.placeholder(tf.float32, [])
accuracy_ph = tf.placeholder(tf.float32, [])
summary_train = [
tf.summary.scalar('avg_sample_loss', loss_ph),
tf.summary.scalar('entropy', entropy_ph),
tf.summary.scalar('avg_accuracy', accuracy_ph)
]
log_step_train = tf.summary.merge(summary_train)
# Training
sess = tf.Session()
sess.run(tf.global_variables_initializer())
snapshot_saver = tf.train.Saver(max_to_keep=None) # keep all snapshots
show_all_variables()
avg_accuracy = 0
accuracy_decay = 0.99
for n_iter, batch in enumerate(data_reader_train.batches()):
if n_iter >= config.max_iter:
break
# set up input and output tensors
h = sess.partial_run_setup(
fetches=[
model.predicted_tokens, model.entropy_reg, scores, avg_sample_loss,
train_step
],
feeds=[
text_seq_batch, seq_len_batch, gt_layout_batch,
compiler.loom_input_tensor, ans_label_batch
])
# Part 1: Generate module layout
tokens, entropy_reg_val = sess.partial_run(
h,
fetches=(model.predicted_tokens, model.entropy_reg),
feed_dict={
text_seq_batch: batch['input_seq_batch'],
seq_len_batch: batch['seq_len_batch'],
gt_layout_batch: batch['gt_layout_batch']
})
# Assemble the layout tokens into network structure
expr_list, expr_validity_array = assembler.assemble(tokens)
# all exprs should be valid (since they are ground-truth)
assert np.all(expr_validity_array)
labels = batch['ans_label_batch']
# Build TensorFlow Fold input for NMN
expr_feed = compiler.build_feed_dict(expr_list)
expr_feed[ans_label_batch] = labels
# Part 2: Run NMN and learning steps
scores_val, avg_sample_loss_val, _ = sess.partial_run(
h, fetches=(scores, avg_sample_loss, train_step), feed_dict=expr_feed)
# Compute accuracy
predictions = np.argmax(scores_val, axis=1)
accuracy = np.mean(
np.logical_and(expr_validity_array, predictions == labels))
avg_accuracy += (1 - accuracy_decay) * (accuracy - avg_accuracy)
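    # avg_accuracy is an exponential moving average of the per-batch accuracy:
    # avg <- accuracy_decay * avg + (1 - accuracy_decay) * accuracy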
# Add to TensorBoard summary
if (n_iter + 1) % config.log_interval == 0:
tf.logging.info('iter = %d\n\t'
'loss = %f, accuracy (cur) = %f, '
'accuracy (avg) = %f, entropy = %f' %
(n_iter + 1, avg_sample_loss_val, accuracy, avg_accuracy,
-entropy_reg_val))
summary = sess.run(
fetches=log_step_train,
feed_dict={
loss_ph: avg_sample_loss_val,
entropy_ph: -entropy_reg_val,
accuracy_ph: avg_accuracy
})
log_writer.add_summary(summary, n_iter + 1)
# Save snapshot
if (n_iter + 1) % config.snapshot_interval == 0:
snapshot_file = os.path.join(config.model_dir, '%08d' % (n_iter + 1))
snapshot_saver.save(sess, snapshot_file, write_meta_graph=False)
tf.logging.info('Snapshot saved to %s' % snapshot_file)
tf.logging.info('Run finished.')
if __name__ == '__main__':
config_raw, unparsed = get_config()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
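# A usage sketch (the script file name is illustrative): training logs
# TensorBoard summaries to config.log_dir and saves a checkpoint to
# config.model_dir every --snapshot_interval iterations, e.g.
#   python train.py --max_iter 20000 --snapshot_interval 1000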
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
# the number of attention inputs each module consumes from the stack
_module_input_num = {
'_key_find': 0,
'_key_filter': 1,
'_val_desc': 1}
_module_output_type = {
'_key_find': 'att',
'_key_filter': 'att',
'_val_desc': 'ans'
}
INVALID_EXPR = 'INVALID_EXPR'
class Assembler:
def __init__(self, config):
# read the module list, and record the index of each module and <eos>
self.module_names = config.module_names
# find the index of <eos>
for n_s in range(len(self.module_names)):
if self.module_names[n_s] == '<eos>':
self.EOS_idx = n_s
break
# build a dictionary from module name to token index
self.name2idx_dict = {
name: n_s
for n_s, name in enumerate(self.module_names)
}
def module_list2tokens(self, module_list, max_len=None):
layout_tokens = [self.name2idx_dict[name] for name in module_list]
if max_len is not None:
if len(module_list) >= max_len:
raise ValueError('Not enough time steps to add <eos>')
layout_tokens += [self.EOS_idx] * (max_len - len(module_list))
return layout_tokens
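  # For illustration: with module_names = ['_key_find', '_key_filter',
  # '_val_desc', '<eos>'], calling
  #   module_list2tokens(['_key_find', '_key_filter', '_val_desc', '<eos>'],
  #                      max_len=5)
  # returns [0, 1, 2, 3, 3]: the layout token indices padded with <eos> (3) up
  # to max_len, exactly the ground-truth layout built in the BatchLoader below.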
def _layout_tokens2str(self, layout_tokens):
return ' '.join([self.module_names[idx] for idx in layout_tokens])
def _invalid_expr(self, layout_tokens, error_str):
return {
'module': INVALID_EXPR,
'expr_str': self._layout_tokens2str(layout_tokens),
'error': error_str
}
def _assemble_layout_tokens(self, layout_tokens, batch_idx):
    # Every module takes a time_idx (its position in the decoded layout, used
    # to slice the decoder's textual parameter) plus a module-specific number
    # of attention inputs. The output type is either attention or answer.
    #
    # The final assembled expression for each instance is as follows:
    # expr_type :=
    #   {'module': '_key_find', 'output_type': 'att',
    #    'time_idx': idx, 'batch_idx': idx}
    #   | {'module': '_key_filter', 'output_type': 'att',
    #      'time_idx': idx, 'batch_idx': idx, 'input_0': <expr_type>}
    #   | {'module': '_val_desc', 'output_type': 'ans',
    #      'time_idx': idx, 'batch_idx': idx, 'input_0': <expr_type>}
    #   | {'module': INVALID_EXPR, 'expr_str': '...', 'error': '...'}
    #     (for invalid expressions)
    #
    # A valid layout must contain <eos>. Assembly fails if it doesn't.
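    # For illustration, the fixed layout used in this repository,
    # [_key_find, _key_filter, _val_desc, <eos>, <eos>], decodes with the stack
    # below as: _key_find pushes an attention expr; _key_filter pops it and
    # pushes a filtered attention expr; _val_desc pops that and pushes the
    # final answer expr; decoding then stops at <eos> with one expr remaining.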
if not np.any(layout_tokens == self.EOS_idx):
return self._invalid_expr(layout_tokens, 'cannot find <eos>')
# Decoding Reverse Polish Notation with a stack
decoding_stack = []
for t in range(len(layout_tokens)):
# decode a module/operation
module_idx = layout_tokens[t]
if module_idx == self.EOS_idx:
break
module_name = self.module_names[module_idx]
expr = {
'module': module_name,
'output_type': _module_output_type[module_name],
'time_idx': t,
'batch_idx': batch_idx
}
input_num = _module_input_num[module_name]
# Check if there are enough input in the stack
if len(decoding_stack) < input_num:
# Invalid expression. Not enough input.
return self._invalid_expr(layout_tokens,
'not enough input for ' + module_name)
# Get the input from stack
for n_input in range(input_num - 1, -1, -1):
stack_top = decoding_stack.pop()
if stack_top['output_type'] != 'att':
# Invalid expression. Input must be attention
return self._invalid_expr(layout_tokens,
'input incompatible for ' + module_name)
expr['input_%d' % n_input] = stack_top
decoding_stack.append(expr)
# After decoding the reverse polish expression, there should be exactly
# one expression in the stack
if len(decoding_stack) != 1:
return self._invalid_expr(
layout_tokens,
'final stack size not equal to 1 (%d remains)' % len(decoding_stack))
result = decoding_stack[0]
# The result type should be answer, not attention
if result['output_type'] != 'ans':
return self._invalid_expr(layout_tokens,
'result type must be ans, not att')
return result
def assemble(self, layout_tokens_batch):
# layout_tokens_batch is a numpy array with shape [max_dec_len, batch_size],
# containing module tokens and <eos>, in Reverse Polish Notation.
_, batch_size = layout_tokens_batch.shape
expr_list = [
self._assemble_layout_tokens(layout_tokens_batch[:, batch_i], batch_i)
for batch_i in range(batch_size)
]
expr_validity = np.array(
[expr['module'] != INVALID_EXPR for expr in expr_list], np.bool)
return expr_list, expr_validity
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import tensorflow_fold as td
from model_n2nmn import netgen_att
from model_n2nmn import assembler
from model_n2nmn.modules import Modules
class Model:
def __init__(self,
config,
kb,
text_seq_batch,
seq_length_batch,
num_vocab_txt,
num_vocab_nmn,
EOS_idx,
num_choices,
decoder_sampling,
use_gt_layout=None,
gt_layout_batch=None,
scope='neural_module_network',
reuse=None):
with tf.variable_scope(scope, reuse=reuse):
# Part 1: Seq2seq RNN to generate module layout tokens
embedding_mat = tf.get_variable(
'embedding_mat', [num_vocab_txt, config.embed_dim_txt],
initializer=tf.contrib.layers.xavier_initializer())
with tf.variable_scope('layout_generation'):
att_seq2seq = netgen_att.AttentionSeq2Seq(
config, text_seq_batch, seq_length_batch, num_vocab_txt,
num_vocab_nmn, EOS_idx, decoder_sampling, embedding_mat,
use_gt_layout, gt_layout_batch)
self.att_seq2seq = att_seq2seq
predicted_tokens = att_seq2seq.predicted_tokens
token_probs = att_seq2seq.token_probs
word_vecs = att_seq2seq.word_vecs
neg_entropy = att_seq2seq.neg_entropy
self.atts = att_seq2seq.atts
self.predicted_tokens = predicted_tokens
self.token_probs = token_probs
self.word_vecs = word_vecs
self.neg_entropy = neg_entropy
# log probability of each generated sequence
self.log_seq_prob = tf.reduce_sum(tf.log(token_probs), axis=0)
# Part 2: Neural Module Network
with tf.variable_scope('layout_execution'):
modules = Modules(config, kb, word_vecs, num_choices, embedding_mat)
self.modules = modules
# Recursion of modules
att_shape = [len(kb)]
# Forward declaration of module recursion
att_expr_decl = td.ForwardDeclaration(td.PyObjectType(),
td.TensorType(att_shape))
# _key_find
case_key_find = td.Record([('time_idx', td.Scalar(dtype='int32')),
('batch_idx', td.Scalar(dtype='int32'))])
case_key_find = case_key_find >> td.ScopedLayer(
modules.KeyFindModule, name_or_scope='KeyFindModule')
# _key_filter
case_key_filter = td.Record([('input_0', att_expr_decl()),
('time_idx', td.Scalar('int32')),
('batch_idx', td.Scalar('int32'))])
case_key_filter = case_key_filter >> td.ScopedLayer(
modules.KeyFilterModule, name_or_scope='KeyFilterModule')
recursion_cases = td.OneOf(
td.GetItem('module'),
{'_key_find': case_key_find,
'_key_filter': case_key_filter})
att_expr_decl.resolve_to(recursion_cases)
# _val_desc: output scores for choice (for valid expressions)
predicted_scores = td.Record([('input_0', recursion_cases),
('time_idx', td.Scalar('int32')),
('batch_idx', td.Scalar('int32'))])
predicted_scores = predicted_scores >> td.ScopedLayer(
modules.ValDescribeModule, name_or_scope='ValDescribeModule')
# For invalid expressions, define a dummy answer
# so that all answers have the same form
INVALID = assembler.INVALID_EXPR
dummy_scores = td.Void() >> td.FromTensor(
np.zeros(num_choices, np.float32))
output_scores = td.OneOf(
td.GetItem('module'),
{'_val_desc': predicted_scores,
INVALID: dummy_scores})
# compile and get the output scores
self.compiler = td.Compiler.create(output_scores)
self.scores = self.compiler.output_tensors[0]
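        # (td.Compiler.create builds one TensorFlow Fold graph over the
        # per-example expression trees: the training/eval scripts feed it via
        # compiler.build_feed_dict(expr_list) and read the stacked answer
        # scores back from output_tensors[0].)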
# Regularization: Entropy + L2
self.entropy_reg = tf.reduce_mean(neg_entropy)
module_weights = [
v for v in tf.trainable_variables()
if (scope in v.op.name and v.op.name.endswith('weights'))
]
self.l2_reg = tf.add_n([tf.nn.l2_loss(v) for v in module_weights])
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
class Modules:
def __init__(self, config, kb, word_vecs, num_choices, embedding_mat):
self.config = config
self.embedding_mat = embedding_mat
# kb has shape [N_kb, 3]
self.kb = kb
self.embed_keys_e, self.embed_keys_r, self.embed_vals_e = self.embed_kb()
# word_vecs has shape [T_decoder, N, D_txt]
self.word_vecs = word_vecs
self.num_choices = num_choices
def embed_kb(self):
keys_e, keys_r, vals_e = [], [], []
for idx_sub, idx_rel, idx_obj in self.kb:
keys_e.append(idx_sub)
keys_r.append(idx_rel)
vals_e.append(idx_obj)
embed_keys_e = tf.nn.embedding_lookup(self.embedding_mat, keys_e)
embed_keys_r = tf.nn.embedding_lookup(self.embedding_mat, keys_r)
embed_vals_e = tf.nn.embedding_lookup(self.embedding_mat, vals_e)
return embed_keys_e, embed_keys_r, embed_vals_e
def _slice_word_vecs(self, time_idx, batch_idx):
# this callable will be wrapped into a td.Function
# In TF Fold, batch_idx and time_idx are both [N_batch, 1] tensors
# time is highest dim in word_vecs
joint_index = tf.stack([time_idx, batch_idx], axis=1)
return tf.gather_nd(self.word_vecs, joint_index)
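  # (So each module receives, as its textual parameter, the attention-weighted
  # word vector produced by the seq2seq decoder at that module's own decoding
  # timestep: one [D_txt] row of word_vecs per example, giving [N, D_txt].)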
# All the layers are wrapped with td.ScopedLayer
def KeyFindModule(self,
time_idx,
batch_idx,
scope='KeyFindModule',
reuse=None):
# In TF Fold, batch_idx and time_idx are both [N_batch, 1] tensors
text_param = self._slice_word_vecs(time_idx, batch_idx)
# Mapping: embed_keys_e x text_param -> att
# Input:
# embed_keys_e: [N_kb, D_txt]
# text_param: [N, D_txt]
# Output:
# att: [N, N_kb]
#
    # Implementation:
    #   1. Inner products (matmul) between text_param and every row of
    #      embed_keys_e
    #   2. L2-normalization over the key dimension
with tf.variable_scope(scope, reuse=reuse):
m = tf.matmul(text_param, self.embed_keys_e, transpose_b=True)
att = tf.nn.l2_normalize(m, dim=1)
return att
def KeyFilterModule(self,
input_0,
time_idx,
batch_idx,
scope='KeyFilterModule',
reuse=None):
att_0 = input_0
text_param = self._slice_word_vecs(time_idx, batch_idx)
# Mapping: and(embed_keys_r x text_param, att) -> att
# Input:
# embed_keys_r: [N_kb, D_txt]
# text_param: [N, D_txt]
# att_0: [N, N_kb]
# Output:
# att: [N, N_kb]
#
    # Implementation:
    #   1. Inner products (matmul) between text_param and every row of
    #      embed_keys_r
    #   2. L2-normalization over the key dimension
    #   3. Elementwise min with the incoming attention att_0
with tf.variable_scope(scope, reuse=reuse):
m = tf.matmul(text_param, self.embed_keys_r, transpose_b=True)
att_1 = tf.nn.l2_normalize(m, dim=1)
att = tf.minimum(att_0, att_1)
return att
def ValDescribeModule(self,
input_0,
time_idx,
batch_idx,
scope='ValDescribeModule',
reuse=None):
att = input_0
# Mapping: att -> answer probs
# Input:
# embed_vals_e: [N_kb, D_txt]
# att: [N, N_kb]
# embedding_mat: [self.num_choices, D_txt]
# Output:
# answer_scores: [N, self.num_choices]
#
    # Implementation:
    #   1. Attention-weighted sum over the value embeddings
    #   2. Dot products between the weighted sum and the L2-normalized
    #      candidate embeddings (cosine similarity up to the norm of the
    #      weighted sum)
with tf.variable_scope(scope, reuse=reuse):
# weighted_sum has shape [N, D_txt]
weighted_sum = tf.matmul(att, self.embed_vals_e)
# scores has shape [N, self.num_choices]
scores = tf.matmul(
weighted_sum,
tf.nn.l2_normalize(self.embedding_mat, dim=1),
transpose_b=True)
return scores
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from util.nn import fc_layer as fc
def _get_lstm_cell(num_layers, lstm_dim):
cell_list = [
tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
for _ in range(num_layers)
]
cell = tf.contrib.rnn.MultiRNNCell(cell_list, state_is_tuple=True)
return cell
class AttentionSeq2Seq:
def __init__(self,
config,
text_seq_batch,
seq_length_batch,
num_vocab_txt,
num_vocab_nmn,
EOS_token,
decoder_sampling,
embedding_mat,
use_gt_layout=None,
gt_layout_batch=None,
scope='encoder_decoder',
reuse=None):
self.T_decoder = config.T_decoder
self.encoder_num_vocab = num_vocab_txt
self.encoder_embed_dim = config.embed_dim_txt
self.decoder_num_vocab = num_vocab_nmn
self.decoder_embed_dim = config.embed_dim_nmn
self.lstm_dim = config.lstm_dim
self.num_layers = config.num_layers
self.EOS_token = EOS_token
self.decoder_sampling = decoder_sampling
self.embedding_mat = embedding_mat
with tf.variable_scope(scope, reuse=reuse):
self._build_encoder(text_seq_batch, seq_length_batch)
self._build_decoder(use_gt_layout, gt_layout_batch)
def _build_encoder(self,
text_seq_batch,
seq_length_batch,
scope='encoder',
reuse=None):
lstm_dim = self.lstm_dim
num_layers = self.num_layers
with tf.variable_scope(scope, reuse=reuse):
T = tf.shape(text_seq_batch)[0]
N = tf.shape(text_seq_batch)[1]
self.T_encoder = T
self.N = N
# text_seq has shape [T, N] and embedded_seq has shape [T, N, D]
embedded_seq = tf.nn.embedding_lookup(self.embedding_mat, text_seq_batch)
self.embedded_input_seq = embedded_seq
# The RNN
cell = _get_lstm_cell(num_layers, lstm_dim)
# encoder_outputs has shape [T, N, lstm_dim]
encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
cell,
embedded_seq,
seq_length_batch,
dtype=tf.float32,
time_major=True,
scope='lstm')
self.encoder_outputs = encoder_outputs
self.encoder_states = encoder_states
      # transform the encoder outputs for later attention alignments
      # encoder_h_transformed has shape [T, N, lstm_dim] after the reshape
encoder_h_transformed = fc(
'encoder_h_transform',
tf.reshape(encoder_outputs, [-1, lstm_dim]),
output_dim=lstm_dim)
encoder_h_transformed = tf.reshape(encoder_h_transformed,
[T, N, lstm_dim])
self.encoder_h_transformed = encoder_h_transformed
# seq_not_finished is a shape [T, N, 1] tensor,
# where seq_not_finished[t, n]
# is 1 iff sequence n is not finished at time t, and 0 otherwise
seq_not_finished = tf.less(
tf.range(T)[:, tf.newaxis, tf.newaxis],
seq_length_batch[:, tf.newaxis])
seq_not_finished = tf.cast(seq_not_finished, tf.float32)
self.seq_not_finished = seq_not_finished
def _build_decoder(self,
use_gt_layout,
gt_layout_batch,
scope='decoder',
reuse=None):
    # The main difference from a plain seq2seq decoder is that this decoder
    # takes an additional input (the attention over the encoder) at each step.
# T_max is the maximum length of decoded sequence (including <eos>)
#
# This function is for decoding only. It performs greedy search or sampling.
# the first input is <go> (its embedding vector) and the subsequent inputs
# are the outputs from previous time step
# num_vocab does not include <go>
#
# use_gt_layout is None or a bool tensor, and gt_layout_batch is a tensor
# with shape [T_max, N].
# If use_gt_layout is not None, then when use_gt_layout is true, predict
# exactly the tokens in gt_layout_batch, regardless of actual probability.
# Otherwise, if sampling is True, sample from the token probability
# If sampling is False, do greedy decoding (beam size 1)
N = self.N
encoder_states = self.encoder_states
T_max = self.T_decoder
lstm_dim = self.lstm_dim
num_layers = self.num_layers
EOS_token = self.EOS_token
sampling = self.decoder_sampling
with tf.variable_scope(scope, reuse=reuse):
embedding_mat = tf.get_variable(
'embedding_mat', [self.decoder_num_vocab, self.decoder_embed_dim])
# we use a separate embedding for <go>, as it is only used in the
# beginning of the sequence
go_embedding = tf.get_variable('go_embedding',
[1, self.decoder_embed_dim])
with tf.variable_scope('att_prediction'):
v = tf.get_variable('v', [lstm_dim])
W_a = tf.get_variable(
'weights', [lstm_dim, lstm_dim],
initializer=tf.contrib.layers.xavier_initializer())
b_a = tf.get_variable(
'biases', lstm_dim, initializer=tf.constant_initializer(0.))
# The parameters to predict the next token
with tf.variable_scope('token_prediction'):
W_y = tf.get_variable(
'weights', [lstm_dim * 2, self.decoder_num_vocab],
initializer=tf.contrib.layers.xavier_initializer())
b_y = tf.get_variable(
'biases',
self.decoder_num_vocab,
initializer=tf.constant_initializer(0.))
# Attentional decoding
# Loop function is called at time t BEFORE the cell execution at time t,
# and its next_input is used as the input at time t (not t+1)
# c.f. https://www.tensorflow.org/api_docs/python/tf/nn/raw_rnn
mask_range = tf.reshape(
tf.range(self.decoder_num_vocab, dtype=tf.int32), [1, -1])
all_eos_pred = EOS_token * tf.ones([N], tf.int32)
all_one_prob = tf.ones([N], tf.float32)
all_zero_entropy = tf.zeros([N], tf.float32)
if use_gt_layout is not None:
gt_layout_mult = tf.cast(use_gt_layout, tf.int32)
pred_layout_mult = 1 - gt_layout_mult
def loop_fn(time, cell_output, cell_state, loop_state):
if cell_output is None: # time == 0
next_cell_state = encoder_states
next_input = tf.tile(go_embedding, [N, 1])
else: # time > 0
next_cell_state = cell_state
# compute the attention map over the input sequence
# a_raw has shape [T, N, 1]
att_raw = tf.reduce_sum(
tf.tanh(
tf.nn.xw_plus_b(cell_output, W_a, b_a) +
self.encoder_h_transformed) * v,
axis=2,
keep_dims=True)
# softmax along the first dimension (T) over not finished examples
# att has shape [T, N, 1]
att = tf.nn.softmax(att_raw, dim=0) * self.seq_not_finished
att = att / tf.reduce_sum(att, axis=0, keep_dims=True)
# d has shape [N, lstm_dim]
d2 = tf.reduce_sum(att * self.encoder_outputs, axis=0)
# token_scores has shape [N, num_vocab]
token_scores = tf.nn.xw_plus_b(
tf.concat([cell_output, d2], axis=1), W_y, b_y)
# predict the next token (behavior depending on parameters)
if sampling:
# predicted_token has shape [N]
logits = token_scores
predicted_token = tf.cast(
tf.reshape(tf.multinomial(token_scores, 1), [-1]), tf.int32)
else:
# predicted_token has shape [N]
predicted_token = tf.cast(tf.argmax(token_scores, 1), tf.int32)
if use_gt_layout is not None:
predicted_token = (gt_layout_batch[time - 1] * gt_layout_mult +
predicted_token * pred_layout_mult)
# token_prob has shape [N], the probability of the predicted token
# although token_prob is not needed for predicting the next token
# it is needed in output (for policy gradient training)
# [N, num_vocab]
# mask has shape [N, num_vocab]
mask = tf.equal(mask_range, tf.reshape(predicted_token, [-1, 1]))
all_token_probs = tf.nn.softmax(token_scores)
token_prob = tf.reduce_sum(
all_token_probs * tf.cast(mask, tf.float32), axis=1)
neg_entropy = tf.reduce_sum(
all_token_probs * tf.log(all_token_probs), axis=1)
# is_eos_predicted is a [N] bool tensor, indicating whether
# <eos> has already been predicted previously in each sequence
is_eos_predicted = loop_state[2]
predicted_token_old = predicted_token
# if <eos> has already been predicted, now predict <eos> with
# prob 1
predicted_token = tf.where(is_eos_predicted, all_eos_pred,
predicted_token)
token_prob = tf.where(is_eos_predicted, all_one_prob, token_prob)
neg_entropy = tf.where(is_eos_predicted, all_zero_entropy,
neg_entropy)
is_eos_predicted = tf.logical_or(is_eos_predicted,
tf.equal(predicted_token_old,
EOS_token))
# the prediction is from the cell output of the last step
# timestep (t-1), feed it as input into timestep t
next_input = tf.nn.embedding_lookup(embedding_mat, predicted_token)
elements_finished = tf.greater_equal(time, T_max)
# loop_state is a 5-tuple, representing
# 1) the predicted_tokens
# 2) the prob of predicted_tokens
# 3) whether <eos> has already been predicted
# 4) the negative entropy of policy (accumulated across timesteps)
# 5) the attention
if loop_state is None: # time == 0
# Write the predicted token into the output
predicted_token_array = tf.TensorArray(
dtype=tf.int32, size=T_max, infer_shape=False)
token_prob_array = tf.TensorArray(
dtype=tf.float32, size=T_max, infer_shape=False)
att_array = tf.TensorArray(
dtype=tf.float32, size=T_max, infer_shape=False)
next_loop_state = (predicted_token_array, token_prob_array, tf.zeros(
[N], dtype=tf.bool), tf.zeros([N], dtype=tf.float32), att_array)
else: # time > 0
t_write = time - 1
next_loop_state = (
loop_state[0].write(t_write, predicted_token),
loop_state[1].write(t_write, token_prob),
is_eos_predicted,
loop_state[3] + neg_entropy,
loop_state[4].write(t_write, att))
return (elements_finished, next_input, next_cell_state, cell_output,
next_loop_state)
# The RNN
cell = _get_lstm_cell(num_layers, lstm_dim)
_, _, decodes_ta = tf.nn.raw_rnn(cell, loop_fn, scope='lstm')
predicted_tokens = decodes_ta[0].stack()
token_probs = decodes_ta[1].stack()
neg_entropy = decodes_ta[3]
# atts has shape [T_decoder, T_encoder, N, 1]
atts = decodes_ta[4].stack()
self.atts = atts
# word_vec has shape [T_decoder, N, D]
word_vecs = tf.reduce_sum(atts * self.embedded_input_seq, axis=1)
predicted_tokens.set_shape([None, None])
token_probs.set_shape([None, None])
neg_entropy.set_shape([None])
word_vecs.set_shape([None, None, self.encoder_embed_dim])
self.predicted_tokens = predicted_tokens
self.token_probs = token_probs
self.neg_entropy = neg_entropy
self.word_vecs = word_vecs
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import namedtuple
from Queue import Queue
import re
import threading
import numpy as np
import tensorflow as tf
Data = namedtuple('Data', ['X', 'Y', 'MultiYs', 'qid'])
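# Each question with k answers is expanded into k (question, answer) samples
# (see SampleBuilder.build_samples below): X holds the question as a list of
# token indices, Y one answer index per sample, MultiYs the full answer set,
# and qid the original question id shared by all k copies.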
class SampleBuilder:
def __init__(self, config):
self.config = config
self.kb_raw = self.read_kb()
self.data_raw = self.read_raw_data()
# dictionary of entities, normal words, and relations
self.dict_all = self.gen_dict()
self.reverse_dict_all = dict(
zip(self.dict_all.values(), self.dict_all.keys()))
tf.logging.info('size of dict: %d' % len(self.dict_all))
self.kb = self.build_kb()
self.data_all = self.build_samples()
def read_kb(self):
kb_raw = []
    for line in open(self.config.KB_file):
sub, rel, obj = line.strip().split('|')
kb_raw.append((sub, rel, obj))
tf.logging.info('# of KB records: %d' % len(kb_raw))
return kb_raw
def read_raw_data(self):
data = dict()
for name in self.config.data_files:
raw = []
tf.logging.info(
'Reading data file {}'.format(self.config.data_files[name]))
      for line in open(self.config.data_files[name]):
question, answers = line.strip().split('\t')
question = question.replace('],', ']') # ignore ',' in the template
raw.append((question, answers))
data[name] = raw
return data
def build_kb(self):
tf.logging.info('Indexing KB...')
kb = []
for sub, rel, obj in self.kb_raw:
kb.append([self.dict_all[sub], self.dict_all[rel], self.dict_all[obj]])
return kb
def gen_dict(self):
s = set()
for sub, rel, obj in self.kb_raw:
s.add(sub)
s.add(rel)
s.add(obj)
for name in self.data_raw:
for question, answers in self.data_raw[name]:
normal = re.split('\[[^\]]+\]', question)
for phrase in normal:
for word in phrase.split():
s.add(word)
s = list(s)
d = {s[idx]: idx for idx in range(len(s))}
return d
def build_samples(self):
def map_entity_idx(text):
entities = re.findall('\[[^\]]+\]', text)
for entity in entities:
entity = entity[1:-1]
index = self.dict_all[entity]
text = text.replace('[%s]' % entity, '@%d' % index)
return text
data_all = dict()
for name in self.data_raw:
X, Y, MultiYs, qid = [], [], [], []
for i, (question, answers) in enumerate(self.data_raw[name]):
qdata, labels = [], []
question = map_entity_idx(question)
for word in question.split():
if word[0] == '@':
qdata.append(int(word[1:]))
else:
qdata.append(self.dict_all[word])
for answer in answers.split('|'):
labels.append(self.dict_all[answer])
if len(qdata) > self.config.T_encoder:
self.config.T_encoder = len(qdata)
for label in labels:
X.append(qdata)
Y.append(label)
MultiYs.append(set(labels))
qid.append(i)
data_all[name] = Data(X=X, Y=Y, MultiYs=MultiYs, qid=qid)
return data_all
def _run_prefetch(prefetch_queue, batch_loader, data, shuffle, one_pass,
config):
assert len(data.X) == len(data.Y) == len(data.MultiYs) == len(data.qid)
num_samples = len(data.X)
batch_size = config.batch_size
n_sample = 0
fetch_order = config.rng.permutation(num_samples)
while True:
sample_ids = fetch_order[n_sample:n_sample + batch_size]
batch = batch_loader.load_one_batch(sample_ids)
prefetch_queue.put(batch, block=True)
n_sample += len(sample_ids)
if n_sample >= num_samples:
if one_pass:
prefetch_queue.put(None, block=True)
n_sample = 0
if shuffle:
fetch_order = config.rng.permutation(num_samples)
class DataReader:
def __init__(self,
config,
data,
assembler,
shuffle=True,
one_pass=False,
prefetch_num=10):
self.config = config
self.data = data
self.assembler = assembler
self.batch_loader = BatchLoader(self.config,
self.data, self.assembler)
self.shuffle = shuffle
self.one_pass = one_pass
self.prefetch_queue = Queue(maxsize=prefetch_num)
self.prefetch_thread = threading.Thread(target=_run_prefetch,
args=(self.prefetch_queue,
self.batch_loader, self.data,
self.shuffle, self.one_pass,
self.config))
self.prefetch_thread.daemon = True
self.prefetch_thread.start()
def batches(self):
while True:
if self.prefetch_queue.empty():
tf.logging.warning('Waiting for data loading (IO is slow)...')
batch = self.prefetch_queue.get(block=True)
if batch is None:
assert self.one_pass
tf.logging.info('One pass finished!')
        return  # end the one-pass iterator
yield batch
class BatchLoader:
def __init__(self, config,
data, assembler):
self.config = config
self.data = data
self.assembler = assembler
self.T_encoder = config.T_encoder
self.T_decoder = config.T_decoder
tf.logging.info('T_encoder: %d' % self.T_encoder)
tf.logging.info('T_decoder: %d' % self.T_decoder)
tf.logging.info('batch size: %d' % self.config.batch_size)
self.gt_layout_tokens = config.gt_layout_tokens
def load_one_batch(self, sample_ids):
actual_batch_size = len(sample_ids)
input_seq_batch = np.zeros((self.T_encoder, actual_batch_size), np.int32)
seq_len_batch = np.zeros(actual_batch_size, np.int32)
ans_label_batch = np.zeros(actual_batch_size, np.int32)
ans_set_labels_list = [None] * actual_batch_size
question_id_list = [None] * actual_batch_size
gt_layout_batch = np.zeros((self.T_decoder, actual_batch_size), np.int32)
for batch_i in range(actual_batch_size):
idx = sample_ids[batch_i]
seq_len = len(self.data.X[idx])
seq_len_batch[batch_i] = seq_len
input_seq_batch[:seq_len, batch_i] = self.data.X[idx]
ans_label_batch[batch_i] = self.data.Y[idx]
ans_set_labels_list[batch_i] = self.data.MultiYs[idx]
question_id_list[batch_i] = self.data.qid[idx]
gt_layout_batch[:, batch_i] = self.assembler.module_list2tokens(
self.gt_layout_tokens, self.T_decoder)
batch = dict(input_seq_batch=input_seq_batch,
seq_len_batch=seq_len_batch,
ans_label_batch=ans_label_batch,
gt_layout_batch=gt_layout_batch,
ans_set_labels_list=ans_set_labels_list,
question_id_list=question_id_list)
return batch
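  # (Each batch is a dict of numpy arrays: input_seq_batch [T_encoder, N],
  # seq_len_batch [N], ans_label_batch [N], gt_layout_batch [T_decoder, N],
  # plus per-example answer sets and question ids as plain Python lists.)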
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from datetime import datetime
import json
import logging
import os
import tensorflow as tf
import tensorflow.contrib.slim as slim
def prepare_dirs_and_logger(config):
formatter = logging.Formatter('%(asctime)s:%(levelname)s::%(message)s')
logger = logging.getLogger('tensorflow')
for hdlr in logger.handlers:
logger.removeHandler(hdlr)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(tf.logging.INFO)
config.log_dir = os.path.join(config.exp_dir, config.log_dir,
config.train_tag)
config.model_dir = os.path.join(config.exp_dir, config.model_dir,
config.train_tag)
config.output_dir = os.path.join(config.exp_dir, config.output_dir,
config.train_tag)
for path in [
config.log_dir, config.model_dir, config.output_dir
]:
if not os.path.exists(path):
os.makedirs(path)
config.data_files = {
'train': os.path.join(config.data_dir, config.train_data_file),
'dev': os.path.join(config.data_dir, config.dev_data_file),
'test': os.path.join(config.data_dir, config.test_data_file)
}
return config
def get_time():
return datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
def show_all_variables():
model_vars = tf.trainable_variables()
slim.model_analyzer.analyze_vars(model_vars, print_info=True)
def save_config(config):
param_path = os.path.join(config.model_dir, 'params.json')
tf.logging.info('log dir: %s' % config.log_dir)
tf.logging.info('model dir: %s' % config.model_dir)
tf.logging.info('param path: %s' % param_path)
tf.logging.info('output dir: %s' % config.output_dir)
with open(param_path, 'w') as f:
f.write(json.dumps(config.__dict__, indent=4, sort_keys=True))
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
def fc_layer(name,
bottom,
output_dim,
bias_term=True,
weights_initializer=None,
biases_initializer=None,
reuse=None):
# flatten bottom input
shape = bottom.get_shape().as_list()
input_dim = 1
for d in shape[1:]:
input_dim *= d
flat_bottom = tf.reshape(bottom, [-1, input_dim])
# weights and biases variables
with tf.variable_scope(name, reuse=reuse):
# initialize the variables
if weights_initializer is None:
weights_initializer = tf.contrib.layers.xavier_initializer()
if bias_term and biases_initializer is None:
biases_initializer = tf.constant_initializer(0.)
# weights has shape [input_dim, output_dim]
weights = tf.get_variable(
'weights', [input_dim, output_dim], initializer=weights_initializer)
if bias_term:
biases = tf.get_variable(
'biases', output_dim, initializer=biases_initializer)
if not reuse:
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
tf.nn.l2_loss(weights))
if bias_term:
fc = tf.nn.xw_plus_b(flat_bottom, weights, biases)
else:
fc = tf.matmul(flat_bottom, weights)
return fc
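# A minimal usage sketch (tensor names are illustrative); the encoder in
# netgen_att above calls this as
#   h = fc('encoder_h_transform', tf.reshape(outputs, [-1, lstm_dim]),
#          output_dim=lstm_dim)
# which creates (or reuses) 'weights' [input_dim, output_dim] and 'biases'
# [output_dim] in that scope and returns xw_plus_b(flat_bottom, weights, biases).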