Commit 1bfe902a authored by Martin Wicke, committed by GitHub

Merge pull request #1073 from knathanieltucker/update-namignizer

Updated Namignizer
parents 45b5353f d6787c0f
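
This PR ports Namignizer to the TensorFlow 1.0 API: the RNN cell classes move from `tf.nn.rnn_cell` to `tf.contrib.rnn`, the multi-layer cell is built from per-layer cell instances, `tf.concat` takes its new argument order, `sequence_loss_by_example` moves to `tf.contrib.legacy_seq2seq`, and `tf.all_variables()` becomes `tf.global_variables()`. As orientation, here is a minimal, self-contained sketch of the cell-construction pattern the patch adopts (assumes TensorFlow 1.x; `size`, `num_layers`, `keep_prob`, and `batch_size` are illustrative values, not the model's config):

```python
# Minimal sketch of the TF 1.x multi-layer LSTM construction used in this
# patch. The concrete numbers below are placeholders for illustration only.
import tensorflow as tf

size, num_layers, keep_prob, batch_size = 128, 2, 0.5, 20
is_training = True

lstm_cells = []
for _ in range(num_layers):
    # each layer gets its own cell object (see the note after the first hunk)
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
    if is_training and keep_prob < 1:
        lstm_cell = tf.contrib.rnn.DropoutWrapper(
            lstm_cell, output_keep_prob=keep_prob)
    lstm_cells.append(lstm_cell)
cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)

initial_state = cell.zero_state(batch_size, tf.float32)
```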
@@ -37,11 +37,14 @@ class NamignizerModel(object):
         self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
 
         # lstm for our RNN cell (GRU supported too)
-        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
-        if is_training and config.keep_prob < 1:
-            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
-                lstm_cell, output_keep_prob=config.keep_prob)
-        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
+        lstm_cells = []
+        for layer in range(config.num_layers):
+            lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
+            if is_training and config.keep_prob < 1:
+                lstm_cell = tf.contrib.rnn.DropoutWrapper(
+                    lstm_cell, output_keep_prob=config.keep_prob)
+            lstm_cells.append(lstm_cell)
+        cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)
 
         self._initial_state = cell.zero_state(batch_size, tf.float32)
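
Why the loop instead of `[lstm_cell] * config.num_layers`: from TensorFlow 1.0 on, RNN cells own their variables, so reusing one cell object across layers would tie every layer to the same weights (and later 1.x releases reject the reuse outright). Constructing a fresh `BasicLSTMCell` per layer gives each layer independent parameters.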
@@ -61,11 +64,11 @@ class NamignizerModel(object):
                 (cell_output, state) = cell(inputs[:, time_step, :], state)
                 outputs.append(cell_output)
 
-        output = tf.reshape(tf.concat(1, outputs), [-1, size])
+        output = tf.reshape(tf.concat(outputs, 1), [-1, size])
         softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
         softmax_b = tf.get_variable("softmax_b", [vocab_size])
         logits = tf.matmul(output, softmax_w) + softmax_b
-        loss = tf.nn.seq2seq.sequence_loss_by_example(
+        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
             [logits],
             [tf.reshape(self._targets, [-1])],
             [self._weights])
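
Two TF 1.0 API changes show up in this hunk: `tf.concat` swapped its argument order from `(axis, values)` to `(values, axis)`, and `sequence_loss_by_example` moved from `tf.nn.seq2seq` to `tf.contrib.legacy_seq2seq`. A minimal sketch of the new call shapes (assumes TensorFlow 1.x; the tensors are illustrative):

```python
import tensorflow as tf

a = tf.zeros([2, 3])
b = tf.ones([2, 3])

# TF 1.0: values first, axis second (pre-1.0 this was tf.concat(1, [a, b]))
merged = tf.concat([a, b], 1)           # shape [2, 6]

logits = tf.zeros([4, 10])              # [batch * steps, vocab_size]
targets = tf.zeros([4], dtype=tf.int32) # flattened target ids
weights = tf.ones([4])                  # per-position loss weights
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
    [logits], [targets], [weights])     # per-example cross-entropy
```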
@@ -77,7 +80,7 @@ class NamignizerModel(object):
         self._activations = tf.nn.softmax(logits)
 
         # ability to save the model
-        self.saver = tf.train.Saver(tf.all_variables())
+        self.saver = tf.train.Saver(tf.global_variables())
 
         if not is_training:
             return
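
`tf.all_variables()` was deprecated in TensorFlow 1.0 in favor of the identically behaved `tf.global_variables()`, so the saver still checkpoints the same variable set.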
@@ -122,7 +122,6 @@ def run_epoch(session, m, names, counts, epoch_size, eval_op, verbose=False):
         cost, _ = session.run([m.cost, eval_op],
                               {m.input_data: x,
                                m.targets: y,
-                               m.initial_state: m.initial_state.eval(),
                                m.weights: np.ones(m.batch_size * m.num_steps)})
         costs += cost
         iters += m.num_steps
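
The `m.initial_state: m.initial_state.eval()` feed is dropped here and in the two hunks below. With `MultiRNNCell` in TF 1.0 the state is a tuple of `LSTMStateTuple`s rather than a single tensor, so the old feed no longer type-checks, and omitting it simply lets `session.run` evaluate the model's `zero_state` default. If state ever needed to be carried across batches, a hedged sketch in the style of the TF 1.x PTB tutorial (hypothetical `batches` iterable; not part of this patch):

```python
import numpy as np

state = session.run(m.initial_state)  # tuple of LSTMStateTuple values
for x, y in batches:
    feed_dict = {m.input_data: x, m.targets: y,
                 m.weights: np.ones(m.batch_size * m.num_steps)}
    # feed each layer's (c, h) pair explicitly
    for i, (c, h) in enumerate(m.initial_state):
        feed_dict[c] = state[i].c
        feed_dict[h] = state[i].h
    cost, state = session.run([m.cost, m.final_state], feed_dict)
```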
@@ -201,7 +200,6 @@ def namignize(names, checkpoint_path, config):
             cost, loss, _ = session.run([m.cost, m.loss, tf.no_op()],
                                         {m.input_data: x,
                                          m.targets: y,
-                                         m.initial_state: m.initial_state.eval(),
                                          m.weights: np.concatenate((
                                              np.ones(len(name)), np.zeros(m.batch_size * m.num_steps - len(name))))})
@@ -234,7 +232,6 @@ def namignator(checkpoint_path, config):
         activations, final_state, _ = session.run([m.activations, m.final_state, tf.no_op()],
                                                   {m.input_data: np.zeros((1, 1)),
                                                    m.targets: np.zeros((1, 1)),
-                                                   m.initial_state: m.initial_state.eval(),
                                                    m.weights: np.ones(1)})
 
         # sample from our softmax activations
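
For context on that last comment, a minimal sketch of drawing a character id from the returned softmax activations (hypothetical helper, NumPy only; `activations` has shape `[1, vocab_size]` here since the generator runs with batch size 1):

```python
import numpy as np

def sample_from_activations(activations):
    """Draw a character id from a [1, vocab_size] softmax output."""
    probs = activations[0]
    probs = probs / probs.sum()  # renormalize against float rounding
    return np.random.choice(len(probs), p=probs)
```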