Commit a2b2088c authored by Martin Kersner's avatar Martin Kersner
Browse files

Replace deprecated get_or_create_global_step

tf.contrib.framework.get_or_create_global_step -> tf.train.get_or_create_global_step
parent 4364390a
...@@ -116,7 +116,7 @@ class VatxtModel(object): ...@@ -116,7 +116,7 @@ class VatxtModel(object):
""" """
def __init__(self, cl_logits_input_dim=None): def __init__(self, cl_logits_input_dim=None):
self.global_step = tf.contrib.framework.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
self.vocab_freqs = _get_vocab_freqs() self.vocab_freqs = _get_vocab_freqs()
# Cache VatxtInput objects # Cache VatxtInput objects
......
...@@ -137,7 +137,7 @@ class Model(object): ...@@ -137,7 +137,7 @@ class Model(object):
self.memory = self.get_memory() self.memory = self.get_memory()
self.classifier = self.get_classifier() self.classifier = self.get_classifier()
self.global_step = tf.contrib.framework.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
def get_embedder(self): def get_embedder(self):
return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim) return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)
......
...@@ -364,7 +364,7 @@ class Trainer(object): ...@@ -364,7 +364,7 @@ class Trainer(object):
if FLAGS.supervisor: if FLAGS.supervisor:
with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)): with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
self.global_step = tf.contrib.framework.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
tf.set_random_seed(FLAGS.tf_seed) tf.set_random_seed(FLAGS.tf_seed)
self.controller = self.get_controller() self.controller = self.get_controller()
self.model = self.controller.model self.model = self.controller.model
...@@ -382,7 +382,7 @@ class Trainer(object): ...@@ -382,7 +382,7 @@ class Trainer(object):
sess = sv.PrepareSession(FLAGS.master) sess = sv.PrepareSession(FLAGS.master)
else: else:
tf.set_random_seed(FLAGS.tf_seed) tf.set_random_seed(FLAGS.tf_seed)
self.global_step = tf.contrib.framework.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
self.controller = self.get_controller() self.controller = self.get_controller()
self.model = self.controller.model self.model = self.controller.model
self.controller.setup() self.controller.setup()
......
...@@ -56,7 +56,7 @@ class ResNet(object): ...@@ -56,7 +56,7 @@ class ResNet(object):
def build_graph(self): def build_graph(self):
"""Build a whole graph for the model.""" """Build a whole graph for the model."""
self.global_step = tf.contrib.framework.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
self._build_model() self._build_model()
if self.mode == 'train': if self.mode == 'train':
self._build_train_op() self._build_train_op()
......
...@@ -411,8 +411,9 @@ class NasNetABaseCell(object): ...@@ -411,8 +411,9 @@ class NasNetABaseCell(object):
tf.summary.scalar('layer_ratio', layer_ratio) tf.summary.scalar('layer_ratio', layer_ratio)
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob) drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
# Decrease the keep probability over time # Decrease the keep probability over time
current_step = tf.cast(tf.contrib.framework.get_or_create_global_step(), current_step = tf.cast(tf.train.get_or_create_global_step(),
tf.float32) tf.float32)
print("HERE")
drop_path_burn_in_steps = self._total_training_steps drop_path_burn_in_steps = self._total_training_steps
current_ratio = ( current_ratio = (
current_step / drop_path_burn_in_steps) current_step / drop_path_burn_in_steps)
......
...@@ -61,7 +61,7 @@ parser.add_argument('--log_frequency', type=int, default=10, ...@@ -61,7 +61,7 @@ parser.add_argument('--log_frequency', type=int, default=10,
def train(): def train():
"""Train CIFAR-10 for a number of steps.""" """Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default(): with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step() global_step = tf.train.get_or_create_global_step()
# Get images and labels for CIFAR-10. # Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
......
...@@ -162,7 +162,7 @@ class PTBModel(object): ...@@ -162,7 +162,7 @@ class PTBModel(object):
optimizer = tf.train.GradientDescentOptimizer(self._lr) optimizer = tf.train.GradientDescentOptimizer(self._lr)
self._train_op = optimizer.apply_gradients( self._train_op = optimizer.apply_gradients(
zip(grads, tvars), zip(grads, tvars),
global_step=tf.contrib.framework.get_or_create_global_step()) global_step=tf.train.get_or_create_global_step())
self._new_lr = tf.placeholder( self._new_lr = tf.placeholder(
tf.float32, shape=[], name="new_learning_rate") tf.float32, shape=[], name="new_learning_rate")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment