Commit e00e0e13 authored by dreamdragon's avatar dreamdragon
Browse files

Merge remote-tracking branch 'upstream/master'

parents b915db4e 402b561b
......@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import median_filter
from light_curve import median_filter
class MedianFilterTest(absltest.TestCase):
......
......@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest
from light_curve_util.periodic_event import Event
from light_curve.periodic_event import Event
class EventTest(absltest.TestCase):
......
......@@ -21,9 +21,9 @@ from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import periodic_event
from light_curve import periodic_event
from light_curve_util import util
from light_curve import util
class LightCurveUtilTest(absltest.TestCase):
......@@ -89,13 +89,13 @@ class LightCurveUtilTest(absltest.TestCase):
]
all_flux = [np.ones(25), np.ones(10)]
self.assertEqual(len(all_time), 2)
self.assertEqual(len(all_time[0]), 25)
self.assertEqual(len(all_time[1]), 10)
self.assertLen(all_time, 2)
self.assertLen(all_time[0], 25)
self.assertLen(all_time[1], 10)
self.assertEqual(len(all_flux), 2)
self.assertEqual(len(all_flux[0]), 25)
self.assertEqual(len(all_flux[1]), 10)
self.assertLen(all_flux, 2)
self.assertLen(all_flux[0], 25)
self.assertLen(all_flux[1], 10)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
......@@ -268,7 +268,7 @@ class LightCurveUtilTest(absltest.TestCase):
np.array([80, 90]),
]
reshard_xs = util.reshard_arrays(xs, ys)
self.assertEqual(5, len(reshard_xs))
self.assertLen(reshard_xs, 5)
np.testing.assert_array_equal([], reshard_xs[0])
np.testing.assert_array_equal([1, 2], reshard_xs[1])
np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
......
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
from astronet.util import config_util
from tf_util import config_util
class ConfigUtilTest(tf.test.TestCase):
......
......@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest
from astronet.util import configdict
from tf_util import configdict
class ConfigDictTest(absltest.TestCase):
......
......@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args):
latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
tf.logging.info("No checkpoint in %s, skipping evaluation.",
estimator.model_dir)
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
......@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args):
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError):
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import example_util
from tf_util import example_util
class ExampleUtilTest(tf.test.TestCase):
......
......@@ -49,14 +49,15 @@ VGGish depends on the following Python packages:
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/)
* [`six`](https://pythonhosted.org/six/)
* [`pysoundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.0. We have tested with Python 2.7.6 and 3.4.3 on an
Ubuntu-like system with NumPy v1.13.1, SciPy v0.19.1, resampy v0.1.5, TensorFlow
v1.2.1, and Six v1.10.0.
be at least version 1.0. We have tested that everything works on Ubuntu and
Windows 10 with Python 3.6.6, Numpy v1.15.4, SciPy v1.1.0, resampy v0.2.1,
TensorFlow v1.3.0, Six v1.11.0 and PySoundFile 0.9.0.
VGGish also requires downloading two data files:
......
......@@ -17,11 +17,12 @@
import numpy as np
import resampy
from scipy.io import wavfile
import mel_features
import vggish_params
import soundfile as sf
def waveform_to_examples(data, sample_rate):
"""Converts audio waveform into an array of examples for VGGish.
......@@ -80,7 +81,7 @@ def wavfile_to_examples(wav_file):
Returns:
See waveform_to_examples.
"""
sr, wav_data = wavfile.read(wav_file)
wav_data, sr = sf.read(wav_file, dtype='int16')
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr)
......@@ -126,9 +126,10 @@ def main(unused_argv):
eval_scales=FLAGS.inference_scales,
add_flipped_images=FLAGS.add_flipped_images)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.float32)
# Crop the valid regions from the predictions.
semantic_predictions = tf.slice(
predictions[common.OUTPUT_TYPE],
predictions,
[0, 0, 0],
[1, resized_image_size[0], resized_image_size[1]])
# Resize back the prediction to the original image size.
......@@ -140,7 +141,7 @@ def main(unused_argv):
label_size,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True)
return tf.squeeze(resized_label, 3)
return tf.cast(tf.squeeze(resized_label, 3), tf.int32)
semantic_predictions = _resize_label(semantic_predictions, image_size)
semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment