Commit c705568b authored by Yanping Huang, committed by Neal Wu

Add different rnn implementation modes to ptb tutorial (#2276)

parent 7e9e15ad
@@ -36,6 +36,13 @@ py_test(
     ],
 )
 
+py_library(
+    name = "util",
+    srcs = ["util.py"],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
 py_binary(
     name = "ptb_word_lm",
     srcs = [
@@ -44,7 +51,8 @@ py_binary(
     srcs_version = "PY2AND3",
     deps = [
         ":reader",
+        ":util",
        "//tensorflow:tensorflow_py",
     ],
 )
...
@@ -19,3 +19,4 @@ from __future__ import division
 from __future__ import print_function
 
 import reader
+import util
@@ -40,6 +40,9 @@ The hyperparameters used in the model:
 - keep_prob - the probability of keeping weights in the dropout layer
 - lr_decay - the decay of the learning rate for each epoch after "max_epoch"
 - batch_size - the batch size
+- rnn_mode - the low level implementation of lstm cell: one of CUDNN,
+             BASIC, or BLOCK, representing cudnn_lstm, basic_lstm, and
+             lstm_block_cell classes.
 
 The data required for this example is in the data/ dir of the
 PTB dataset from Tomas Mikolov's webpage:
@@ -56,13 +59,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import time
 
 import numpy as np
 import tensorflow as tf
 
 import reader
+import util
+
+from tensorflow.python.client import device_lib
 
 flags = tf.flags
 logging = tf.logging
@@ -76,8 +81,18 @@ flags.DEFINE_string("save_path", None,
                     "Model output directory.")
 flags.DEFINE_bool("use_fp16", False,
                   "Train using 16-bit floats instead of 32bit floats")
+flags.DEFINE_integer("num_gpus", 1,
+                     "If larger than 1, Grappler AutoParallel optimizer "
+                     "will create multiple training replicas with each GPU "
+                     "running one replica.")
+flags.DEFINE_string("rnn_mode", None,
+                    "The low level implementation of lstm cell: one of CUDNN, "
+                    "BASIC, and BLOCK, representing cudnn_lstm, basic_lstm, "
+                    "and lstm_block_cell classes.")
 
 FLAGS = flags.FLAGS
+BASIC = "basic"
+CUDNN = "cudnn"
+BLOCK = "block"
 
 
 def data_type():
@@ -99,39 +114,15 @@ class PTBModel(object):
   """The PTB model."""
 
   def __init__(self, is_training, config, input_):
+    self._is_training = is_training
     self._input = input_
-
-    batch_size = input_.batch_size
-    num_steps = input_.num_steps
+    self._rnn_params = None
+    self._cell = None
+    self.batch_size = input_.batch_size
+    self.num_steps = input_.num_steps
     size = config.hidden_size
     vocab_size = config.vocab_size
-    # Slightly better results can be obtained with forget gate biases
-    # initialized to 1 but the hyperparameters of the model would need to be
-    # different than reported in the paper.
-    def lstm_cell():
-      # With the latest TensorFlow source code (as of Mar 27, 2017),
-      # the BasicLSTMCell will need a reuse parameter which is unfortunately
-      # not defined in TensorFlow 1.0. To maintain backwards compatibility,
-      # we add an argument check here:
-      if 'reuse' in inspect.getargspec(
-          tf.contrib.rnn.BasicLSTMCell.__init__).args:
-        return tf.contrib.rnn.BasicLSTMCell(
-            size, forget_bias=0.0, state_is_tuple=True,
-            reuse=tf.get_variable_scope().reuse)
-      else:
-        return tf.contrib.rnn.BasicLSTMCell(
-            size, forget_bias=0.0, state_is_tuple=True)
-    attn_cell = lstm_cell
-    if is_training and config.keep_prob < 1:
-      def attn_cell():
-        return tf.contrib.rnn.DropoutWrapper(
-            lstm_cell(), output_keep_prob=config.keep_prob)
-    cell = tf.contrib.rnn.MultiRNNCell(
-        [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
-
-    self._initial_state = cell.zero_state(batch_size, data_type())
 
     with tf.device("/cpu:0"):
       embedding = tf.get_variable(
           "embedding", [vocab_size, size], dtype=data_type())
@@ -140,43 +131,25 @@ class PTBModel(object):
     if is_training and config.keep_prob < 1:
       inputs = tf.nn.dropout(inputs, config.keep_prob)
 
-    # Simplified version of models/tutorials/rnn/rnn.py's rnn().
-    # This builds an unrolled LSTM for tutorial purposes only.
-    # In general, use the rnn() or state_saving_rnn() from rnn.py.
-    #
-    # The alternative version of the code below is:
-    #
-    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
-    # outputs, state = tf.contrib.rnn.static_rnn(
-    #     cell, inputs, initial_state=self._initial_state)
-    outputs = []
-    state = self._initial_state
-    with tf.variable_scope("RNN"):
-      for time_step in range(num_steps):
-        if time_step > 0: tf.get_variable_scope().reuse_variables()
-        (cell_output, state) = cell(inputs[:, time_step, :], state)
-        outputs.append(cell_output)
-
-    output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
+    output, state = self._build_rnn_graph(inputs, config, is_training)
+
     softmax_w = tf.get_variable(
         "softmax_w", [size, vocab_size], dtype=data_type())
     softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
-    logits = tf.matmul(output, softmax_w) + softmax_b
+    logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)
+    # Reshape logits to be a 3-D tensor for sequence loss
+    logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])
 
-    # Reshape logits to be 3-D tensor for sequence loss
-    logits = tf.reshape(logits, [batch_size, num_steps, vocab_size])
-
-    # use the contrib sequence loss and average over the batches
+    # Use the contrib sequence loss and average over the batches
     loss = tf.contrib.seq2seq.sequence_loss(
         logits,
         input_.targets,
-        tf.ones([batch_size, num_steps], dtype=data_type()),
+        tf.ones([self.batch_size, self.num_steps], dtype=data_type()),
         average_across_timesteps=False,
-        average_across_batch=True
-    )
+        average_across_batch=True)
 
-    # update the cost variables
-    self._cost = cost = tf.reduce_sum(loss)
+    # Update the cost
+    self._cost = tf.reduce_sum(loss)
     self._final_state = state
 
     if not is_training:
@@ -184,7 +157,7 @@ class PTBModel(object):
 
     self._lr = tf.Variable(0.0, trainable=False)
     tvars = tf.trainable_variables()
-    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+    grads, _ = tf.clip_by_global_norm(tf.gradients(self._cost, tvars),
                                       config.max_grad_norm)
     optimizer = tf.train.GradientDescentOptimizer(self._lr)
     self._train_op = optimizer.apply_gradients(
@@ -195,9 +168,120 @@ class PTBModel(object):
         tf.float32, shape=[], name="new_learning_rate")
     self._lr_update = tf.assign(self._lr, self._new_lr)
+
+  def _build_rnn_graph(self, inputs, config, is_training):
+    if config.rnn_mode == CUDNN:
+      return self._build_rnn_graph_cudnn(inputs, config, is_training)
+    else:
+      return self._build_rnn_graph_lstm(inputs, config, is_training)
+
+  def _build_rnn_graph_cudnn(self, inputs, config, is_training):
+    """Build the inference graph using CUDNN cell."""
+    inputs = tf.transpose(inputs, [1, 0, 2])
+    self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
+        num_layers=config.num_layers,
+        num_units=config.hidden_size,
+        input_size=config.hidden_size,
+        dropout=1 - config.keep_prob if is_training else 0)
+    params_size_t = self._cell.params_size()
+    self._rnn_params = tf.get_variable(
+        "lstm_params",
+        initializer=tf.random_uniform(
+            [params_size_t], -config.init_scale, config.init_scale),
+        validate_shape=False)
+    c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
+                 tf.float32)
+    h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
+                 tf.float32)
+    self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
+    outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
+    outputs = tf.transpose(outputs, [1, 0, 2])
+    outputs = tf.reshape(outputs, [-1, config.hidden_size])
+    return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
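CudnnLSTM consumes time-major input, hence the transposes on entry and exit. A standalone numpy sketch of the layout change, with illustrative PTB-like shapes:

    import numpy as np

    x = np.zeros((20, 35, 200))      # batch-major: [batch, steps, hidden]
    t = np.transpose(x, (1, 0, 2))   # time-major:  [steps, batch, hidden]
    assert t.shape == (35, 20, 200)
    assert np.transpose(t, (1, 0, 2)).shape == x.shape  # inverse on the way out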
+
+  def _get_lstm_cell(self, config, is_training):
+    if config.rnn_mode == BASIC:
+      return tf.contrib.rnn.BasicLSTMCell(
+          config.hidden_size, forget_bias=0.0, state_is_tuple=True,
+          reuse=not is_training)
+    if config.rnn_mode == BLOCK:
+      return tf.contrib.rnn.LSTMBlockCell(
+          config.hidden_size, forget_bias=0.0)
+    raise ValueError("rnn_mode %s not supported" % config.rnn_mode)
+
+  def _build_rnn_graph_lstm(self, inputs, config, is_training):
+    """Build the inference graph using canonical LSTM cells."""
+    # Slightly better results can be obtained with forget gate biases
+    # initialized to 1 but the hyperparameters of the model would need to be
+    # different than reported in the paper.
+    def make_cell():
+      # Build a fresh cell per layer; sharing one cell object across layers
+      # would make MultiRNNCell reuse the same variables in every layer.
+      cell = self._get_lstm_cell(config, is_training)
+      if is_training and config.keep_prob < 1:
+        cell = tf.contrib.rnn.DropoutWrapper(
+            cell, output_keep_prob=config.keep_prob)
+      return cell
+
+    cell = tf.contrib.rnn.MultiRNNCell(
+        [make_cell() for _ in range(config.num_layers)], state_is_tuple=True)
+
+    self._initial_state = cell.zero_state(config.batch_size, data_type())
+    state = self._initial_state
+    # Simplified version of tensorflow_models/tutorials/rnn/rnn.py's rnn().
+    # This builds an unrolled LSTM for tutorial purposes only.
+    # In general, use the rnn() or state_saving_rnn() from rnn.py.
+    #
+    # The alternative version of the code below is:
+    #
+    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
+    # outputs, state = tf.contrib.rnn.static_rnn(cell, inputs,
+    #                                            initial_state=self._initial_state)
+    outputs = []
+    with tf.variable_scope("RNN"):
+      for time_step in range(self.num_steps):
+        if time_step > 0: tf.get_variable_scope().reuse_variables()
+        (cell_output, state) = cell(inputs[:, time_step, :], state)
+        outputs.append(cell_output)
+    output = tf.reshape(tf.concat(outputs, 1), [-1, config.hidden_size])
+    return output, state
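The loop above unrolls the cell num_steps times under a shared variable scope, which is what static_rnn would do internally. A self-contained TF 1.x sketch of the same pattern, with hypothetical sizes in place of the config values:

    import tensorflow as tf

    batch, steps, hidden = 4, 5, 8   # illustrative, not the tutorial's config
    inputs = tf.placeholder(tf.float32, [batch, steps, hidden])
    cell = tf.contrib.rnn.LSTMBlockCell(hidden, forget_bias=0.0)
    state = cell.zero_state(batch, tf.float32)
    outputs = []
    with tf.variable_scope("RNN"):
      for t in range(steps):
        if t > 0:
          tf.get_variable_scope().reuse_variables()  # share weights across steps
        out, state = cell(inputs[:, t, :], state)
        outputs.append(out)
    # [batch * steps, hidden], the layout the softmax layer expects
    flat = tf.reshape(tf.concat(outputs, 1), [-1, hidden])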
 
   def assign_lr(self, session, lr_value):
     session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
+
+  def export_ops(self, name):
+    """Exports ops to collections."""
+    self._name = name
+    ops = {util.with_prefix(self._name, "cost"): self._cost}
+    if self._is_training:
+      ops.update(lr=self._lr, new_lr=self._new_lr, lr_update=self._lr_update)
+      if self._rnn_params:
+        ops.update(rnn_params=self._rnn_params)
+    for name, op in ops.items():
+      tf.add_to_collection(name, op)
+    self._initial_state_name = util.with_prefix(self._name, "initial")
+    self._final_state_name = util.with_prefix(self._name, "final")
+    util.export_state_tuples(self._initial_state, self._initial_state_name)
+    util.export_state_tuples(self._final_state, self._final_state_name)
+
+  def import_ops(self):
+    """Imports ops from collections."""
+    if self._is_training:
+      self._train_op = tf.get_collection_ref("train_op")[0]
+      self._lr = tf.get_collection_ref("lr")[0]
+      self._new_lr = tf.get_collection_ref("new_lr")[0]
+      self._lr_update = tf.get_collection_ref("lr_update")[0]
+      rnn_params = tf.get_collection_ref("rnn_params")
+      if self._cell and rnn_params:
+        params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
+            self._cell,
+            self._cell.params_to_canonical,
+            self._cell.canonical_to_params,
+            rnn_params,
+            base_variable_scope="Model/RNN")
+        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
+    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
+    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
+    self._initial_state = util.import_state_tuples(
+        self._initial_state, self._initial_state_name, num_replicas)
+    self._final_state = util.import_state_tuples(
+        self._final_state, self._final_state_name, num_replicas)
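export_ops and import_ops round-trip the model's tensors through named graph collections, so the metagraph can be rewritten by Grappler and re-imported into a fresh graph. A minimal sketch of the underlying mechanism, with a hypothetical tensor and collection name:

    import tensorflow as tf

    with tf.Graph().as_default():
      cost = tf.constant(3.0, name="cost")
      tf.add_to_collection("Train/cost", cost)        # what export_ops does
      metagraph = tf.train.export_meta_graph()

    with tf.Graph().as_default():
      tf.train.import_meta_graph(metagraph)
      cost = tf.get_collection_ref("Train/cost")[0]   # what import_ops does
      # cost now refers to the corresponding tensor in the imported graph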
 
   @property
   def input(self):
     return self._input
@@ -222,6 +306,14 @@ class PTBModel(object):
   def train_op(self):
     return self._train_op
 
+  @property
+  def initial_state_name(self):
+    return self._initial_state_name
+
+  @property
+  def final_state_name(self):
+    return self._final_state_name
 
 
 class SmallConfig(object):
   """Small config."""
@@ -237,6 +329,7 @@ class SmallConfig(object):
   lr_decay = 0.5
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = CUDNN
 
 
 class MediumConfig(object):
@@ -253,6 +346,7 @@ class MediumConfig(object):
   lr_decay = 0.8
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 class LargeConfig(object):
@@ -269,6 +363,7 @@ class LargeConfig(object):
   lr_decay = 1 / 1.15
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 class TestConfig(object):
@@ -285,6 +380,7 @@ class TestConfig(object):
   lr_decay = 0.5
   batch_size = 20
   vocab_size = 10000
+  rnn_mode = BLOCK
 
 
 def run_epoch(session, model, eval_op=None, verbose=False):
@@ -317,27 +413,43 @@ def run_epoch(session, model, eval_op=None, verbose=False):
     if verbose and step % (model.input.epoch_size // 10) == 10:
       print("%.3f perplexity: %.3f speed: %.0f wps" %
             (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
-             iters * model.input.batch_size / (time.time() - start_time)))
+             iters * model.input.batch_size * max(1, FLAGS.num_gpus) /
+             (time.time() - start_time)))
 
   return np.exp(costs / iters)
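Since each AutoParallel replica consumes its own batch per step, the words-per-second estimate now scales by max(1, FLAGS.num_gpus). A worked example with illustrative numbers:

    iters, batch_size, num_gpus, elapsed = 3500, 20, 2, 10.0
    wps = iters * batch_size * max(1, num_gpus) / elapsed
    assert wps == 14000.0   # twice the single-replica throughput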
 
 
 def get_config():
+  """Get model config."""
+  config = None
   if FLAGS.model == "small":
-    return SmallConfig()
+    config = SmallConfig()
   elif FLAGS.model == "medium":
-    return MediumConfig()
+    config = MediumConfig()
   elif FLAGS.model == "large":
-    return LargeConfig()
+    config = LargeConfig()
   elif FLAGS.model == "test":
-    return TestConfig()
+    config = TestConfig()
   else:
-    raise ValueError("Invalid model: %s", FLAGS.model)
+    raise ValueError("Invalid model: %s" % FLAGS.model)
+  if FLAGS.rnn_mode:
+    config.rnn_mode = FLAGS.rnn_mode
+  if FLAGS.num_gpus != 1 or tf.__version__ < "1.3.0":
+    config.rnn_mode = BASIC
+  return config
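One caveat: tf.__version__ < "1.3.0" is a lexicographic string comparison, which misorders double-digit components. A sketch of the pitfall and a numeric alternative, assuming plain release version strings:

    assert "1.10.0" < "1.3.0"   # string order, numerically wrong

    def version_tuple(v):
      return tuple(int(p) for p in v.split(".")[:3])

    assert version_tuple("1.10.0") > version_tuple("1.3.0")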
 
 
 def main(_):
   if not FLAGS.data_path:
     raise ValueError("Must set --data_path to PTB data directory")
+  gpus = [
+      x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"
+  ]
+  if FLAGS.num_gpus > len(gpus):
+    raise ValueError(
+        "Your machine has only %d gpus "
+        "which is less than the requested --num_gpus=%d."
+        % (len(gpus), FLAGS.num_gpus))
 
   raw_data = reader.ptb_raw_data(FLAGS.data_path)
   train_data, valid_data, test_data, _ = raw_data
@@ -365,13 +477,31 @@ def main(_):
       tf.summary.scalar("Validation Loss", mvalid.cost)
 
     with tf.name_scope("Test"):
-      test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
+      test_input = PTBInput(
+          config=eval_config, data=test_data, name="TestInput")
       with tf.variable_scope("Model", reuse=True, initializer=initializer):
         mtest = PTBModel(is_training=False, config=eval_config,
                          input_=test_input)
+
+    models = {"Train": m, "Valid": mvalid, "Test": mtest}
+    for name, model in models.items():
+      model.export_ops(name)
+    metagraph = tf.train.export_meta_graph()
+    if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
+      raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
+                       "below 1.1.0")
+    soft_placement = False
+    if FLAGS.num_gpus > 1:
+      soft_placement = True
+      util.auto_parallel(metagraph, m)
+
+  with tf.Graph().as_default():
+    tf.train.import_meta_graph(metagraph)
+    for model in models.values():
+      model.import_ops()
     sv = tf.train.Supervisor(logdir=FLAGS.save_path)
-    with sv.managed_session() as session:
+    config_proto = tf.ConfigProto(allow_soft_placement=soft_placement)
+    with sv.managed_session(config=config_proto) as session:
       for i in range(config.max_max_epoch):
         lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
         m.assign_lr(session, config.learning_rate * lr_decay)
...
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Grappler autoparallel optimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.core.framework import variable_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+FLAGS = tf.flags.FLAGS
+
+
+def export_state_tuples(state_tuples, name):
+  for state_tuple in state_tuples:
+    tf.add_to_collection(name, state_tuple.c)
+    tf.add_to_collection(name, state_tuple.h)
+
+
+def import_state_tuples(state_tuples, name, num_replicas):
+  restored = []
+  for i in range(len(state_tuples) * num_replicas):
+    c = tf.get_collection_ref(name)[2 * i + 0]
+    h = tf.get_collection_ref(name)[2 * i + 1]
+    restored.append(tf.contrib.rnn.LSTMStateTuple(c, h))
+  return tuple(restored)
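export_state_tuples flattens each state tuple into the collection as [c0, h0, c1, h1, ...], and import_state_tuples reads the entries back pairwise. A minimal round-trip using these helpers, run in a fresh default graph with illustrative shapes:

    import tensorflow as tf

    c = tf.zeros([20, 200])
    h = tf.zeros([20, 200])
    export_state_tuples([tf.contrib.rnn.LSTMStateTuple(c, h)], "Train/initial")
    # the first argument is only used for its length here
    (restored,) = import_state_tuples([None], "Train/initial", 1)
    # restored.c is c and restored.h is h; with num_replicas > 1 the copies
    # appended by UpdateCollection.replicate_states are read back as well.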
+
+
+def with_prefix(prefix, name):
+  """Adds prefix to name."""
+  return "/".join((prefix, name))
+
+
+def with_autoparallel_prefix(replica_id, name):
+  return with_prefix("AutoParallel-Replica-%d" % replica_id, name)
+
+
+class UpdateCollection(object):
+  """Update collection info in MetaGraphDef for AutoParallel optimizer."""
+
+  def __init__(self, metagraph, model):
+    self._metagraph = metagraph
+    self.replicate_states(model.initial_state_name)
+    self.replicate_states(model.final_state_name)
+    self.update_snapshot_name("variables")
+    self.update_snapshot_name("trainable_variables")
+
+  def update_snapshot_name(self, var_coll_name):
+    var_list = self._metagraph.collection_def[var_coll_name]
+    for i, value in enumerate(var_list.bytes_list.value):
+      var_def = variable_pb2.VariableDef()
+      var_def.ParseFromString(value)
+      # Somehow node Model/global_step/read doesn't have any fanout and seems
+      # to be only used for snapshot; this is different from all other
+      # variables.
+      if var_def.snapshot_name != "Model/global_step/read:0":
+        var_def.snapshot_name = with_autoparallel_prefix(
+            0, var_def.snapshot_name)
+      value = var_def.SerializeToString()
+      var_list.bytes_list.value[i] = value
+
+  def replicate_states(self, state_coll_name):
+    state_list = self._metagraph.collection_def[state_coll_name]
+    num_states = len(state_list.node_list.value)
+    # Append a copy of the original states for each extra replica...
+    for replica_id in range(1, FLAGS.num_gpus):
+      for i in range(num_states):
+        state_list.node_list.value.append(state_list.node_list.value[i])
+    # ...then rename every state with its replica's prefix.
+    for replica_id in range(FLAGS.num_gpus):
+      for i in range(num_states):
+        index = replica_id * num_states + i
+        state_list.node_list.value[index] = with_autoparallel_prefix(
+            replica_id, state_list.node_list.value[index])
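AutoParallel clones the graph once per replica and prefixes every cloned node, which replicate_states mirrors in the state collections: for num_gpus=2 each original entry becomes two prefixed ones. For example:

    assert with_autoparallel_prefix(0, "Model/RNN/state") == \
        "AutoParallel-Replica-0/Model/RNN/state"
    assert with_autoparallel_prefix(1, "Model/RNN/state") == \
        "AutoParallel-Replica-1/Model/RNN/state"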
+
+
+def auto_parallel(metagraph, model):
+  from tensorflow.python.grappler import tf_optimizer
+  rewriter_config = rewriter_config_pb2.RewriterConfig()
+  rewriter_config.optimizers.append("autoparallel")
+  rewriter_config.auto_parallel.enable = True
+  rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus
+  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
+  metagraph.graph_def.CopyFrom(optimized_graph)
+  UpdateCollection(metagraph, model)