"docs/vscode:/vscode.git/clone" did not exist on "e5b92d2b7944fb192c93293fbb62d5c7ba3c5591"
Commit e00e0e13 authored by dreamdragon's avatar dreamdragon
Browse files

Merge remote-tracking branch 'upstream/master'

parents b915db4e 402b561b
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from light_curve_util import median_filter from light_curve import median_filter
class MedianFilterTest(absltest.TestCase): class MedianFilterTest(absltest.TestCase):
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from light_curve_util.periodic_event import Event from light_curve.periodic_event import Event
class EventTest(absltest.TestCase): class EventTest(absltest.TestCase):
......
...@@ -21,9 +21,9 @@ from __future__ import print_function ...@@ -21,9 +21,9 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from light_curve_util import periodic_event from light_curve import periodic_event
from light_curve_util import util from light_curve import util
class LightCurveUtilTest(absltest.TestCase): class LightCurveUtilTest(absltest.TestCase):
...@@ -89,13 +89,13 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -89,13 +89,13 @@ class LightCurveUtilTest(absltest.TestCase):
] ]
all_flux = [np.ones(25), np.ones(10)] all_flux = [np.ones(25), np.ones(10)]
self.assertEqual(len(all_time), 2) self.assertLen(all_time, 2)
self.assertEqual(len(all_time[0]), 25) self.assertLen(all_time[0], 25)
self.assertEqual(len(all_time[1]), 10) self.assertLen(all_time[1], 10)
self.assertEqual(len(all_flux), 2) self.assertLen(all_flux, 2)
self.assertEqual(len(all_flux[0]), 25) self.assertLen(all_flux[0], 25)
self.assertEqual(len(all_flux[1]), 10) self.assertLen(all_flux[1], 10)
# Gap width 0.5. # Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5) split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
...@@ -268,7 +268,7 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -268,7 +268,7 @@ class LightCurveUtilTest(absltest.TestCase):
np.array([80, 90]), np.array([80, 90]),
] ]
reshard_xs = util.reshard_arrays(xs, ys) reshard_xs = util.reshard_arrays(xs, ys)
self.assertEqual(5, len(reshard_xs)) self.assertLen(reshard_xs, 5)
np.testing.assert_array_equal([], reshard_xs[0]) np.testing.assert_array_equal([], reshard_xs[0])
np.testing.assert_array_equal([1, 2], reshard_xs[1]) np.testing.assert_array_equal([1, 2], reshard_xs[1])
np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2]) np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
......
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from astronet.util import config_util from tf_util import config_util
class ConfigUtilTest(tf.test.TestCase): class ConfigUtilTest(tf.test.TestCase):
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from astronet.util import configdict from tf_util import configdict
class ConfigDictTest(absltest.TestCase): class ConfigDictTest(absltest.TestCase):
......
...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args): ...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args):
latest_checkpoint = estimator.latest_checkpoint() latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint: if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint. # This is expected if the training job has not yet saved a checkpoint.
tf.logging.info("No checkpoint in %s, skipping evaluation.",
estimator.model_dir)
return global_step, values return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint) tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args): ...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args):
input_fn, steps=eval_steps, name=eval_name) input_fn, steps=eval_steps, name=eval_name)
if global_step is None: if global_step is None:
global_step = values[eval_name].get("global_step") global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError): except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the # Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases. # in some cases.
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import example_util from tf_util import example_util
class ExampleUtilTest(tf.test.TestCase): class ExampleUtilTest(tf.test.TestCase):
......
...@@ -49,14 +49,15 @@ VGGish depends on the following Python packages: ...@@ -49,14 +49,15 @@ VGGish depends on the following Python packages:
* [`resampy`](http://resampy.readthedocs.io/en/latest/) * [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/) * [`tensorflow`](http://www.tensorflow.org/)
* [`six`](https://pythonhosted.org/six/) * [`six`](https://pythonhosted.org/six/)
* [`pysoundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below). example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.0. We have tested with Python 2.7.6 and 3.4.3 on an be at least version 1.0. We have tested that everything works on Ubuntu and
Ubuntu-like system with NumPy v1.13.1, SciPy v0.19.1, resampy v0.1.5, TensorFlow Windows 10 with Python 3.6.6, Numpy v1.15.4, SciPy v1.1.0, resampy v0.2.1,
v1.2.1, and Six v1.10.0. TensorFlow v1.3.0, Six v1.11.0 and PySoundFile 0.9.0.
VGGish also requires downloading two data files: VGGish also requires downloading two data files:
......
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
import numpy as np import numpy as np
import resampy import resampy
from scipy.io import wavfile
import mel_features import mel_features
import vggish_params import vggish_params
import soundfile as sf
def waveform_to_examples(data, sample_rate): def waveform_to_examples(data, sample_rate):
"""Converts audio waveform into an array of examples for VGGish. """Converts audio waveform into an array of examples for VGGish.
...@@ -80,7 +81,7 @@ def wavfile_to_examples(wav_file): ...@@ -80,7 +81,7 @@ def wavfile_to_examples(wav_file):
Returns: Returns:
See waveform_to_examples. See waveform_to_examples.
""" """
sr, wav_data = wavfile.read(wav_file) wav_data, sr = sf.read(wav_file, dtype='int16')
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr) return waveform_to_examples(samples, sr)
...@@ -126,9 +126,10 @@ def main(unused_argv): ...@@ -126,9 +126,10 @@ def main(unused_argv):
eval_scales=FLAGS.inference_scales, eval_scales=FLAGS.inference_scales,
add_flipped_images=FLAGS.add_flipped_images) add_flipped_images=FLAGS.add_flipped_images)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.float32)
# Crop the valid regions from the predictions. # Crop the valid regions from the predictions.
semantic_predictions = tf.slice( semantic_predictions = tf.slice(
predictions[common.OUTPUT_TYPE], predictions,
[0, 0, 0], [0, 0, 0],
[1, resized_image_size[0], resized_image_size[1]]) [1, resized_image_size[0], resized_image_size[1]])
# Resize back the prediction to the original image size. # Resize back the prediction to the original image size.
...@@ -140,7 +141,7 @@ def main(unused_argv): ...@@ -140,7 +141,7 @@ def main(unused_argv):
label_size, label_size,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True) align_corners=True)
return tf.squeeze(resized_label, 3) return tf.cast(tf.squeeze(resized_label, 3), tf.int32)
semantic_predictions = _resize_label(semantic_predictions, image_size) semantic_predictions = _resize_label(semantic_predictions, image_size)
semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME) semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment