Commit c15fada2 authored by Neal Wu's avatar Neal Wu
Browse files

Rewrite to use inspect.getargspec

parent 167b6c69
...@@ -56,6 +56,7 @@ from __future__ import absolute_import ...@@ -56,6 +56,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import inspect
import time import time
import numpy as np import numpy as np
...@@ -111,13 +112,14 @@ class PTBModel(object): ...@@ -111,13 +112,14 @@ class PTBModel(object):
def lstm_cell(): def lstm_cell():
# With the latest TensorFlow source code (as of Mar 27, 2017), # With the latest TensorFlow source code (as of Mar 27, 2017),
# the BasicLSTMCell will need a reuse parameter which is unfortunately not # the BasicLSTMCell will need a reuse parameter which is unfortunately not
# defined in TensorFlow 1.0. To maintain backwards compatibility, we add a # defined in TensorFlow 1.0. To maintain backwards compatibility, we add
# try-except here: # an argument check here:
try: if 'reuse' in inspect.getargspec(
tf.contrib.rnn.BasicLSTMCell.__init__).args:
return tf.contrib.rnn.BasicLSTMCell( return tf.contrib.rnn.BasicLSTMCell(
size, forget_bias=0.0, state_is_tuple=True, size, forget_bias=0.0, state_is_tuple=True,
reuse=tf.get_variable_scope().reuse) reuse=tf.get_variable_scope().reuse)
except TypeError: else:
return tf.contrib.rnn.BasicLSTMCell( return tf.contrib.rnn.BasicLSTMCell(
size, forget_bias=0.0, state_is_tuple=True) size, forget_bias=0.0, state_is_tuple=True)
attn_cell = lstm_cell attn_cell = lstm_cell
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment