Commit c45213c9 authored by jianchao-li's avatar jianchao-li Committed by Mark Daoust
Browse files

Use distutils.version.StrictVersion for version comparisons

parent a989673b
......@@ -22,12 +22,14 @@ import sys
import six.moves.urllib.request as request
from distutils.version import StrictVersion
tf.logging.set_verbosity(tf.logging.INFO)
# Check that we have correct TensorFlow version installed
tf_version = tf.__version__
tf.logging.info("TensorFlow version: {}".format(tf_version))
assert "1.4" <= tf_version, "TensorFlow r1.4 or later is needed"
assert StrictVersion("1.4") <= StrictVersion(tf_version), "TensorFlow r1.4 or later is needed"
# Windows users: You only need to change PATH, rest is platform independent
PATH = "/tmp/tf_custom_estimators"
......
......@@ -22,10 +22,12 @@ import os
import six.moves.urllib.request as request
import tensorflow as tf
from distutils.version import StrictVersion
# Check that we have correct TensorFlow version installed
tf_version = tf.__version__
print("TensorFlow version: {}".format(tf_version))
assert "1.4" <= tf_version, "TensorFlow r1.4 or later is needed"
assert StrictVersion("1.4") <= StrictVersion(tf_version), "TensorFlow r1.4 or later is needed"
# Windows users: You only need to change PATH, rest is platform independent
PATH = "/tmp/tf_dataset_and_estimator_apis"
......
......@@ -69,6 +69,8 @@ import util
from tensorflow.python.client import device_lib
from distutils.version import StrictVersion
flags = tf.flags
logging = tf.logging
......@@ -436,7 +438,7 @@ def get_config():
raise ValueError("Invalid model: %s", FLAGS.model)
if FLAGS.rnn_mode:
config.rnn_mode = FLAGS.rnn_mode
if FLAGS.num_gpus != 1 or tf.__version__ < "1.3.0" :
if FLAGS.num_gpus != 1 or StrictVersion(tf.__version__) < StrictVersion("1.3.0") :
config.rnn_mode = BASIC
return config
......@@ -489,7 +491,7 @@ def main(_):
for name, model in models.items():
model.export_ops(name)
metagraph = tf.train.export_meta_graph()
if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1:
if StrictVersion(tf.__version__) < StrictVersion("1.1.0") and FLAGS.num_gpus > 1:
raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
"below 1.1.0")
soft_placement = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment