Commit c705568b authored by Yanping Huang's avatar Yanping Huang Committed by Neal Wu
Browse files

Add different rnn implementation modes to ptb tutorial (#2276)

parent 7e9e15ad
......@@ -36,6 +36,13 @@ py_test(
],
)
py_library(
name = "util",
srcs = ["util.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
py_binary(
name = "ptb_word_lm",
srcs = [
......@@ -44,7 +51,8 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":reader",
"//tensorflow:tensorflow_py",
":util",
"//tensorflow:tensorflow_py,
],
)
......
......@@ -19,3 +19,4 @@ from __future__ import division
from __future__ import print_function
import reader
import util
......@@ -40,6 +40,9 @@ The hyperparameters used in the model:
- keep_prob - the probability of keeping weights in the dropout layer
- lr_decay - the decay of the learning rate for each epoch after "max_epoch"
- batch_size - the batch size
- rnn_mode - the low level implementation of lstm cell: one of CUDNN,
BASIC, or BLOCK, representing cudnn_lstm, basic_lstm, and
lstm_block_cell classes.
The data required for this example is in the data/ dir of the
PTB dataset from Tomas Mikolov's webpage:
......@@ -56,13 +59,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import time
import numpy as np
import tensorflow as tf
import reader
import util
from tensorflow.python.client import device_lib
flags = tf.flags
logging = tf.logging
......@@ -76,8 +81,18 @@ flags.DEFINE_string("save_path", None,
"Model output directory.")
flags.DEFINE_bool("use_fp16", False,
"Train using 16-bit floats instead of 32bit floats")
flags.DEFINE_integer("num_gpus", 1,
"If larger than 1, Grappler AutoParallel optimizer "
"will create multiple training replicas with each GPU "
"running one replica.")
flags.DEFINE_string("rnn_mode", None,
"The low level implementation of lstm cell: one of CUDNN, "
"BASIC, and BLOCK, representing cudnn_lstm, basic_lstm, "
"and lstm_block_cell classes.")
FLAGS = flags.FLAGS
BASIC = "basic"
CUDNN = "cudnn"
BLOCK = "block"
def data_type():
......@@ -99,39 +114,15 @@ class PTBModel(object):
"""The PTB model."""
def __init__(self, is_training, config, input_):
self._is_training = is_training
self._input = input_
batch_size = input_.batch_size
num_steps = input_.num_steps
self._rnn_params = None
self._cell = None
self.batch_size = input_.batch_size
self.num_steps = input_.num_steps
size = config.hidden_size
vocab_size = config.vocab_size
# Slightly better results can be obtained with forget gate biases
# initialized to 1 but the hyperparameters of the model would need to be
# different than reported in the paper.
def lstm_cell():
# With the latest TensorFlow source code (as of Mar 27, 2017),
# the BasicLSTMCell will need a reuse parameter which is unfortunately not
# defined in TensorFlow 1.0. To maintain backwards compatibility, we add
# an argument check here:
if 'reuse' in inspect.getargspec(
tf.contrib.rnn.BasicLSTMCell.__init__).args:
return tf.contrib.rnn.BasicLSTMCell(
size, forget_bias=0.0, state_is_tuple=True,
reuse=tf.get_variable_scope().reuse)
else:
return tf.contrib.rnn.BasicLSTMCell(
size, forget_bias=0.0, state_is_tuple=True)
attn_cell = lstm_cell
if is_training and config.keep_prob < 1:
def attn_cell():
return tf.contrib.rnn.DropoutWrapper(
lstm_cell(), output_keep_prob=config.keep_prob)
cell = tf.contrib.rnn.MultiRNNCell(
[attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
self._initial_state = cell.zero_state(batch_size, data_type())
with tf.device("/cpu:0"):
embedding = tf.get_variable(
"embedding", [vocab_size, size], dtype=data_type())
......@@ -140,43 +131,25 @@ class PTBModel(object):
if is_training and config.keep_prob < 1:
inputs = tf.nn.dropout(inputs, config.keep_prob)
# Simplified version of models/tutorials/rnn/rnn.py's rnn().
# This builds an unrolled LSTM for tutorial purposes only.
# In general, use the rnn() or state_saving_rnn() from rnn.py.
#
# The alternative version of the code below is:
#
# inputs = tf.unstack(inputs, num=num_steps, axis=1)
# outputs, state = tf.contrib.rnn.static_rnn(
# cell, inputs, initial_state=self._initial_state)
outputs = []
state = self._initial_state
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
output, state = self._build_rnn_graph(inputs, config, is_training)
output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
softmax_w = tf.get_variable(
"softmax_w", [size, vocab_size], dtype=data_type())
softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
logits = tf.matmul(output, softmax_w) + softmax_b
logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
# Reshape logits to be a 3-D tensor for sequence loss
logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])
# Reshape logits to be 3-D tensor for sequence loss
logits = tf.reshape(logits, [batch_size, num_steps, vocab_size])
# use the contrib sequence loss and average over the batches
# Use the contrib sequence loss and average over the batches
loss = tf.contrib.seq2seq.sequence_loss(
logits,
input_.targets,
tf.ones([batch_size, num_steps], dtype=data_type()),
tf.ones([self.batch_size, self.num_steps], dtype=data_type()),
average_across_timesteps=False,
average_across_batch=True
)
average_across_batch=True)
# update the cost variables
self._cost = cost = tf.reduce_sum(loss)
# Update the cost
self._cost = tf.reduce_sum(loss)
self._final_state = state
if not is_training:
......@@ -184,7 +157,7 @@ class PTBModel(object):
self._lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars),
config.max_grad_norm)
optimizer = tf.train.GradientDescentOptimizer(self._lr)
self._train_op = optimizer.apply_gradients(
......@@ -195,9 +168,120 @@ class PTBModel(object):
tf.float32, shape=[], name="new_learning_rate")
self._lr_update = tf.assign(self._lr, self._new_lr)
def _build_rnn_graph(self, inputs, config, is_training):
if config.rnn_mode == CUDNN:
return self._build_rnn_graph_cudnn(inputs, config, is_training)
else:
return self._build_rnn_graph_lstm(inputs, config, is_training)
def _build_rnn_graph_cudnn(self, inputs, config, is_training):
"""Build the inference graph using CUDNN cell."""
inputs = tf.transpose(inputs, [1, 0, 2])
self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
num_layers=config.num_layers,
num_units=config.hidden_size,
input_size=config.hidden_size,
dropout=1 - config.keep_prob if is_training else 0)
params_size_t = self._cell.params_size()
self._rnn_params = tf.get_variable(
"lstm_params",
initializer=tf.random_uniform(
[params_size_t], -config.init_scale, config.init_scale),
validate_shape=False)
c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
tf.float32)
h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
tf.float32)
self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
outputs = tf.transpose(outputs, [1, 0, 2])
outputs = tf.reshape(outputs, [-1, config.hidden_size])
return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
def _get_lstm_cell(self, config, is_training):
if config.rnn_mode == BASIC:
return tf.contrib.rnn.BasicLSTMCell(
config.hidden_size, forget_bias=0.0, state_is_tuple=True,
reuse=not is_training)
if config.rnn_mode == BLOCK:
return tf.contrib.rnn.LSTMBlockCell(
config.hidden_size, forget_bias=0.0)
raise ValueError("rnn_mode %s not supported" % config.rnn_mode)
def _build_rnn_graph_lstm(self, inputs, config, is_training):
"""Build the inference graph using canonical LSTM cells."""
# Slightly better results can be obtained with forget gate biases
# initialized to 1 but the hyperparameters of the model would need to be
# different than reported in the paper.
cell = self._get_lstm_cell(config, is_training)
if is_training and config.keep_prob < 1:
cell = tf.contrib.rnn.DropoutWrapper(
cell, output_keep_prob=config.keep_prob)
cell = tf.contrib.rnn.MultiRNNCell(
[cell for _ in range(config.num_layers)], state_is_tuple=True)
self._initial_state = cell.zero_state(config.batch_size, data_type())
state = self._initial_state
# Simplified version of tensorflow_models/tutorials/rnn/rnn.py's rnn().
# This builds an unrolled LSTM for tutorial purposes only.
# In general, use the rnn() or state_saving_rnn() from rnn.py.
#
# The alternative version of the code below is:
#
# inputs = tf.unstack(inputs, num=num_steps, axis=1)
# outputs, state = tf.contrib.rnn.static_rnn(cell, inputs,
# initial_state=self._initial_state)
outputs = []
with tf.variable_scope("RNN"):
for time_step in range(self.num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
output = tf.reshape(tf.concat(outputs, 1), [-1, config.hidden_size])
return output, state
def assign_lr(self, session, lr_value):
session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
def export_ops(self, name):
"""Exports ops to collections."""
self._name = name
ops = {util.with_prefix(self._name, "cost"): self._cost}
if self._is_training:
ops.update(lr=self._lr, new_lr=self._new_lr, lr_update=self._lr_update)
if self._rnn_params:
ops.update(rnn_params=self._rnn_params)
for name, op in ops.iteritems():
tf.add_to_collection(name, op)
self._initial_state_name = util.with_prefix(self._name, "initial")
self._final_state_name = util.with_prefix(self._name, "final")
util.export_state_tuples(self._initial_state, self._initial_state_name)
util.export_state_tuples(self._final_state, self._final_state_name)
def import_ops(self):
"""Imports ops from collections."""
if self._is_training:
self._train_op = tf.get_collection_ref("train_op")[0]
self._lr = tf.get_collection_ref("lr")[0]
self._new_lr = tf.get_collection_ref("new_lr")[0]
self._lr_update = tf.get_collection_ref("lr_update")[0]
rnn_params = tf.get_collection_ref("rnn_params")
if self._cell and rnn_params:
params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
self._cell,
self._cell.params_to_canonical,
self._cell.canonical_to_params,
rnn_params,
base_variable_scope="Model/RNN")
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
self._initial_state = util.import_state_tuples(
self._initial_state, self._initial_state_name, num_replicas)
self._final_state = util.import_state_tuples(
self._final_state, self._final_state_name, num_replicas)
@property
def input(self):
return self._input
......@@ -222,6 +306,14 @@ class PTBModel(object):
def train_op(self):
return self._train_op
@property
def initial_state_name(self):
return self._initial_state_name
@property
def final_state_name(self):
return self._final_state_name
class SmallConfig(object):
"""Small config."""
......@@ -237,6 +329,7 @@ class SmallConfig(object):
lr_decay = 0.5
batch_size = 20
vocab_size = 10000
rnn_mode = CUDNN
class MediumConfig(object):
......@@ -253,6 +346,7 @@ class MediumConfig(object):
lr_decay = 0.8
batch_size = 20
vocab_size = 10000
rnn_mode = BLOCK
class LargeConfig(object):
......@@ -269,6 +363,7 @@ class LargeConfig(object):
lr_decay = 1 / 1.15
batch_size = 20
vocab_size = 10000
rnn_mode = BLOCK
class TestConfig(object):
......@@ -285,6 +380,7 @@ class TestConfig(object):
lr_decay = 0.5
batch_size = 20
vocab_size = 10000
rnn_mode = BLOCK
def run_epoch(session, model, eval_op=None, verbose=False):
......@@ -317,27 +413,43 @@ def run_epoch(session, model, eval_op=None, verbose=False):
if verbose and step % (model.input.epoch_size // 10) == 10:
print("%.3f perplexity: %.3f speed: %.0f wps" %
(step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
iters * model.input.batch_size / (time.time() - start_time)))
iters * model.input.batch_size * max(1, FLAGS.num_gpus) /
(time.time() - start_time)))
return np.exp(costs / iters)
def get_config():
"""Get model config."""
config = None
if FLAGS.model == "small":
return SmallConfig()
config = SmallConfig()
elif FLAGS.model == "medium":
return MediumConfig()
config = MediumConfig()
elif FLAGS.model == "large":
return LargeConfig()
config = LargeConfig()
elif FLAGS.model == "test":
return TestConfig()
config = TestConfig()
else:
raise ValueError("Invalid model: %s", FLAGS.model)
if FLAGS.rnn_mode:
config.rnn_mode = FLAGS.rnn_mode
if FLAGS.num_gpus != 1 or tf.__version__ < "1.3.0" :
config.rnn_mode = BASIC
return config
def main(_):
if not FLAGS.data_path:
raise ValueError("Must set --data_path to PTB data directory")
gpus = [
x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"
]
if FLAGS.num_gpus > len(gpus):
raise ValueError(
"Your machine has only %d gpus "
"which is less than the requested --num_gpus=%d."
% (len(gpus), FLAGS.num_gpus))
raw_data = reader.ptb_raw_data(FLAGS.data_path)
train_data, valid_data, test_data, _ = raw_data
......@@ -365,13 +477,31 @@ def main(_):
tf.summary.scalar("Validation Loss", mvalid.cost)
with tf.name_scope("Test"):
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
test_input = PTBInput(
config=eval_config, data=test_data, name="TestInput")
with tf.variable_scope("Model", reuse=True, initializer=initializer):
mtest = PTBModel(is_training=False, config=eval_config,
input_=test_input)
models = {"Train": m, "Valid": mvalid, "Test": mtest}
for name, model in models.iteritems():
model.export_ops(name)
metagraph = tf.train.export_meta_graph()
if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
"below 1.1.0")
soft_placement = False
if FLAGS.num_gpus > 1:
soft_placement = True
util.auto_parallel(metagraph, m)
with tf.Graph().as_default():
tf.train.import_meta_graph(metagraph)
for model in models.values():
model.import_ops()
sv = tf.train.Supervisor(logdir=FLAGS.save_path)
with sv.managed_session() as session:
config_proto = tf.ConfigProto(allow_soft_placement=soft_placement)
with sv.managed_session(config=config_proto) as session:
for i in range(config.max_max_epoch):
lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
m.assign_lr(session, config.learning_rate * lr_decay)
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Grappler autoparallel optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
FLAGS = tf.flags.FLAGS
def export_state_tuples(state_tuples, name):
for state_tuple in state_tuples:
tf.add_to_collection(name, state_tuple.c)
tf.add_to_collection(name, state_tuple.h)
def import_state_tuples(state_tuples, name, num_replicas):
restored = []
for i in range(len(state_tuples) * num_replicas):
c = tf.get_collection_ref(name)[2 * i + 0]
h = tf.get_collection_ref(name)[2 * i + 1]
restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
return tuple(restored)
def with_prefix(prefix, name):
"""Adds prefix to name."""
return "/".join((prefix, name))
def with_autoparallel_prefix(replica_id, name):
return with_prefix("AutoParallel-Replica-%d" % replica_id, name)
class UpdateCollection(object):
"""Update collection info in MetaGraphDef for AutoParallel optimizer."""
def __init__(self, metagraph, model):
self._metagraph = metagraph
self.replicate_states(model.initial_state_name)
self.replicate_states(model.final_state_name)
self.update_snapshot_name("variables")
self.update_snapshot_name("trainable_variables")
def update_snapshot_name(self, var_coll_name):
var_list = self._metagraph.collection_def[var_coll_name]
for i, value in enumerate(var_list.bytes_list.value):
var_def = variable_pb2.VariableDef()
var_def.ParseFromString(value)
# Somehow node Model/global_step/read doesn't have any fanout and seems to
# be only used for snapshot; this is different from all other variables.
if var_def.snapshot_name != "Model/global_step/read:0":
var_def.snapshot_name = with_autoparallel_prefix(
0, var_def.snapshot_name)
value = var_def.SerializeToString()
var_list.bytes_list.value[i] = value
def replicate_states(self, state_coll_name):
state_list = self._metagraph.collection_def[state_coll_name]
num_states = len(state_list.node_list.value)
for replica_id in range(1, FLAGS.num_gpus):
for i in range(num_states):
state_list.node_list.value.append(state_list.node_list.value[i])
for replica_id in range(FLAGS.num_gpus):
for i in range(num_states):
index = replica_id * num_states + i
state_list.node_list.value[index] = with_autoparallel_prefix(
replica_id, state_list.node_list.value[index])
def auto_parallel(metagraph, model):
from google3.third_party.tensorflow.python.grappler import tf_optimizer
rewriter_config = rewriter_config_pb2.RewriterConfig()
rewriter_config.optimizers.append("autoparallel")
rewriter_config.auto_parallel.enable = True
rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
metagraph.graph_def.CopyFrom(optimized_graph)
UpdateCollection(metagraph, model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment