"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "a466b3c2fe22c866523afea8cffbee242f4489a5"
Commit 211ee00a authored by Neal Wu's avatar Neal Wu
Browse files

Convert tf.GraphKeys.VARIABLES -> tf.GraphKeys.GLOBAL_VARIABLES

parent 5978a4a1
...@@ -65,9 +65,9 @@ class InceptionTest(tf.test.TestCase): ...@@ -65,9 +65,9 @@ class InceptionTest(tf.test.TestCase):
inception.inception_resnet_v2(inputs, num_classes) inception.inception_resnet_v2(inputs, num_classes)
with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
inception.inception_resnet_v2(inputs, num_classes) inception.inception_resnet_v2(inputs, num_classes)
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
self.assertDeviceEqual(v.device, '/cpu:0') self.assertDeviceEqual(v.device, '/cpu:0')
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
self.assertDeviceEqual(v.device, '/gpu:0') self.assertDeviceEqual(v.device, '/gpu:0')
def testHalfSizeImages(self): def testHalfSizeImages(self):
......
...@@ -146,9 +146,9 @@ class InceptionTest(tf.test.TestCase): ...@@ -146,9 +146,9 @@ class InceptionTest(tf.test.TestCase):
inception.inception_v4(inputs, num_classes) inception.inception_v4(inputs, num_classes)
with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
inception.inception_v4(inputs, num_classes) inception.inception_v4(inputs, num_classes)
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
self.assertDeviceEqual(v.device, '/cpu:0') self.assertDeviceEqual(v.device, '/cpu:0')
for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
self.assertDeviceEqual(v.device, '/gpu:0') self.assertDeviceEqual(v.device, '/gpu:0')
def testHalfSizeImages(self): def testHalfSizeImages(self):
......
...@@ -196,7 +196,7 @@ def main(unused_argv): ...@@ -196,7 +196,7 @@ def main(unused_argv):
print 'Constructing saver.' print 'Constructing saver.'
# Make saver. # Make saver.
saver = tf.train.Saver( saver = tf.train.Saver(
tf.get_collection(tf.GraphKeys.VARIABLES), max_to_keep=0) tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
# Make training session. # Make training session.
sess = tf.InteractiveSession() sess = tf.InteractiveSession()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment