"python/vscode:/vscode.git/clone" did not exist on "976bc302e52b12d1d2e581cc5d8a952ac1c6b0a4"
Unverified Commit 1f3bbc02 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #2795 from martinkersner/deprecated_get_or_create_global_step

Replace deprecated get_or_create_global_step
parents 01aa7a4a a2b2088c
......@@ -116,7 +116,7 @@ class VatxtModel(object):
"""
def __init__(self, cl_logits_input_dim=None):
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.train.get_or_create_global_step()
self.vocab_freqs = _get_vocab_freqs()
# Cache VatxtInput objects
......
......@@ -137,7 +137,7 @@ class Model(object):
self.memory = self.get_memory()
self.classifier = self.get_classifier()
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.train.get_or_create_global_step()
def get_embedder(self):
return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)
......
......@@ -364,7 +364,7 @@ class Trainer(object):
if FLAGS.supervisor:
with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.train.get_or_create_global_step()
tf.set_random_seed(FLAGS.tf_seed)
self.controller = self.get_controller()
self.model = self.controller.model
......@@ -382,7 +382,7 @@ class Trainer(object):
sess = sv.PrepareSession(FLAGS.master)
else:
tf.set_random_seed(FLAGS.tf_seed)
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.train.get_or_create_global_step()
self.controller = self.get_controller()
self.model = self.controller.model
self.controller.setup()
......
......@@ -56,7 +56,7 @@ class ResNet(object):
def build_graph(self):
"""Build a whole graph for the model."""
self.global_step = tf.contrib.framework.get_or_create_global_step()
self.global_step = tf.train.get_or_create_global_step()
self._build_model()
if self.mode == 'train':
self._build_train_op()
......
......@@ -411,8 +411,9 @@ class NasNetABaseCell(object):
tf.summary.scalar('layer_ratio', layer_ratio)
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
# Decrease the keep probability over time
current_step = tf.cast(tf.contrib.framework.get_or_create_global_step(),
current_step = tf.cast(tf.train.get_or_create_global_step(),
tf.float32)
print("HERE")
drop_path_burn_in_steps = self._total_training_steps
current_ratio = (
current_step / drop_path_burn_in_steps)
......
......@@ -61,7 +61,7 @@ parser.add_argument('--log_frequency', type=int, default=10,
def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()
global_step = tf.train.get_or_create_global_step()
# Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on
......
......@@ -162,7 +162,7 @@ class PTBModel(object):
optimizer = tf.train.GradientDescentOptimizer(self._lr)
self._train_op = optimizer.apply_gradients(
zip(grads, tvars),
global_step=tf.contrib.framework.get_or_create_global_step())
global_step=tf.train.get_or_create_global_step())
self._new_lr = tf.placeholder(
tf.float32, shape=[], name="new_learning_rate")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment