Commit c45213c9 authored by jianchao-li's avatar jianchao-li Committed by Mark Daoust
Browse files

Use distutils.version.StrictVersion for version comparisons

parent a989673b
...@@ -22,12 +22,14 @@ import sys ...@@ -22,12 +22,14 @@ import sys
import six.moves.urllib.request as request import six.moves.urllib.request as request
from distutils.version import StrictVersion
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
# Check that we have correct TensorFlow version installed # Check that we have correct TensorFlow version installed
tf_version = tf.__version__ tf_version = tf.__version__
tf.logging.info("TensorFlow version: {}".format(tf_version)) tf.logging.info("TensorFlow version: {}".format(tf_version))
assert "1.4" <= tf_version, "TensorFlow r1.4 or later is needed" assert StrictVersion("1.4") <= StrictVersion(tf_version), "TensorFlow r1.4 or later is needed"
# Windows users: You only need to change PATH, rest is platform independent # Windows users: You only need to change PATH, rest is platform independent
PATH = "/tmp/tf_custom_estimators" PATH = "/tmp/tf_custom_estimators"
......
...@@ -22,10 +22,12 @@ import os ...@@ -22,10 +22,12 @@ import os
import six.moves.urllib.request as request import six.moves.urllib.request as request
import tensorflow as tf import tensorflow as tf
from distutils.version import StrictVersion
# Check that we have correct TensorFlow version installed # Check that we have correct TensorFlow version installed
tf_version = tf.__version__ tf_version = tf.__version__
print("TensorFlow version: {}".format(tf_version)) print("TensorFlow version: {}".format(tf_version))
assert "1.4" <= tf_version, "TensorFlow r1.4 or later is needed" assert StrictVersion("1.4") <= StrictVersion(tf_version), "TensorFlow r1.4 or later is needed"
# Windows users: You only need to change PATH, rest is platform independent # Windows users: You only need to change PATH, rest is platform independent
PATH = "/tmp/tf_dataset_and_estimator_apis" PATH = "/tmp/tf_dataset_and_estimator_apis"
......
...@@ -69,6 +69,8 @@ import util ...@@ -69,6 +69,8 @@ import util
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
from distutils.version import StrictVersion
flags = tf.flags flags = tf.flags
logging = tf.logging logging = tf.logging
...@@ -436,7 +438,7 @@ def get_config(): ...@@ -436,7 +438,7 @@ def get_config():
raise ValueError("Invalid model: %s", FLAGS.model) raise ValueError("Invalid model: %s", FLAGS.model)
if FLAGS.rnn_mode: if FLAGS.rnn_mode:
config.rnn_mode = FLAGS.rnn_mode config.rnn_mode = FLAGS.rnn_mode
if FLAGS.num_gpus != 1 or tf.__version__ < "1.3.0" : if FLAGS.num_gpus != 1 or StrictVersion(tf.__version__) < StrictVersion("1.3.0") :
config.rnn_mode = BASIC config.rnn_mode = BASIC
return config return config
...@@ -489,7 +491,7 @@ def main(_): ...@@ -489,7 +491,7 @@ def main(_):
for name, model in models.items(): for name, model in models.items():
model.export_ops(name) model.export_ops(name)
metagraph = tf.train.export_meta_graph() metagraph = tf.train.export_meta_graph()
if tf.__version__ < "1.1.0" and FLAGS.num_gpus > 1: if StrictVersion(tf.__version__) < StrictVersion("1.1.0") and FLAGS.num_gpus > 1:
raise ValueError("num_gpus > 1 is not supported for TensorFlow versions " raise ValueError("num_gpus > 1 is not supported for TensorFlow versions "
"below 1.1.0") "below 1.1.0")
soft_placement = False soft_placement = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment