ModelZoo / ResNet50_tensorflow / Commits

Commit 42da7864, authored Dec 04, 2018 by Christopher Shallue
Parent: 17c2f0cc

    Move tensorflow_models/research/astronet to google-research/exoplanet-ml

Changes: 130 files. Showing 20 changed files with 0 additions and 3248 deletions.
research/astronet/astronet/ops/dataset_ops_test.py              +0 -638
research/astronet/astronet/ops/input_ops.py                     +0 -88
research/astronet/astronet/ops/input_ops_test.py                +0 -135
research/astronet/astronet/ops/metrics.py                       +0 -153
research/astronet/astronet/ops/metrics_test.py                  +0 -280
research/astronet/astronet/ops/test_data/test_dataset.tfrecord  +0 -0
research/astronet/astronet/ops/testing.py                       +0 -71
research/astronet/astronet/ops/training.py                      +0 -106
research/astronet/astronet/predict.py                           +0 -173
research/astronet/astronet/train.py                             +0 -133
research/astronet/astronet/util/BUILD                           +0 -14
research/astronet/astronet/util/__init__.py                     +0 -14
research/astronet/astronet/util/estimator_util.py               +0 -252
research/astronet/astrowavenet/BUILD                            +0 -49
research/astronet/astrowavenet/README.md                        +0 -41
research/astronet/astrowavenet/__init__.py                      +0 -14
research/astronet/astrowavenet/astrowavenet_model.py            +0 -321
research/astronet/astrowavenet/astrowavenet_model_test.py       +0 -631
research/astronet/astrowavenet/configurations.py                +0 -76
research/astronet/astrowavenet/data/BUILD                       +0 -59
research/astronet/astronet/ops/dataset_ops_test.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for dataset_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from absl import flags
import numpy as np
import tensorflow as tf

from astronet.ops import dataset_ops
from tf_util import configdict

FLAGS = flags.FLAGS

flags.DEFINE_string("test_srcdir", "", "Test source directory.")

_TEST_TFRECORD_FILE = "astronet/ops/test_data/test_dataset.tfrecord"


class DatasetOpsTest(tf.test.TestCase):

  def testPadTensorToBatchSize(self):
    with self.test_session():
      # Cannot pad a 0-dimensional Tensor.
      tensor_0d = tf.constant(1)
      with self.assertRaises(ValueError):
        dataset_ops.pad_tensor_to_batch_size(tensor_0d, 10)

      # 1-dimensional Tensor. Un-padded batch size is 5.
      tensor_1d = tf.range(5, dtype=tf.int32)
      self.assertEqual([5], tensor_1d.shape)
      self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d.eval())

      tensor_1d_pad5 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 5)
      self.assertEqual([5], tensor_1d_pad5.shape)
      self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d_pad5.eval())

      tensor_1d_pad8 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 8)
      self.assertEqual([8], tensor_1d_pad8.shape)
      self.assertAllEqual([0, 1, 2, 3, 4, 0, 0, 0], tensor_1d_pad8.eval())

      # 2-dimensional Tensor. Un-padded batch size is 3.
      tensor_2d = tf.reshape(tf.range(9, dtype=tf.int32), [3, 3])
      self.assertEqual([3, 3], tensor_2d.shape)
      self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], tensor_2d.eval())

      tensor_2d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 3)
      self.assertEqual([3, 3], tensor_2d_pad3.shape)
      self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]],
                          tensor_2d_pad3.eval())

      tensor_2d_pad4 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 4)
      self.assertEqual([4, 3], tensor_2d_pad4.shape)
      self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8], [0, 0, 0]],
                          tensor_2d_pad4.eval())

  def testPadDatasetToBatchSizeNoWeights(self):
    values = {"labels": np.arange(10, dtype=np.int32)}
    dataset = tf.data.Dataset.from_tensor_slices(values).batch(4)
    self.assertItemsEqual(["labels"], dataset.output_shapes.keys())
    self.assertFalse(dataset.output_shapes["labels"].is_fully_defined())

    dataset_pad = dataset_ops.pad_dataset_to_batch_size(dataset, 4)
    self.assertItemsEqual(["labels", "weights"],
                          dataset_pad.output_shapes.keys())
    self.assertEqual([4], dataset_pad.output_shapes["labels"])
    self.assertEqual([4], dataset_pad.output_shapes["weights"])

    next_batch = dataset_pad.make_one_shot_iterator().get_next()
    next_labels = next_batch["labels"]
    next_weights = next_batch["weights"]

    with self.test_session() as sess:
      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([0, 1, 2, 3], labels)
      self.assertAllClose([1, 1, 1, 1], weights)

      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([4, 5, 6, 7], labels)
      self.assertAllClose([1, 1, 1, 1], weights)

      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([8, 9, 0, 0], labels)
      self.assertAllClose([1, 1, 0, 0], weights)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run([next_labels, next_weights])

  def testPadDatasetToBatchSizeWithWeights(self):
    values = {
        "labels": np.arange(10, dtype=np.int32),
        "weights": 100 + np.arange(10, dtype=np.int32)
    }
    dataset = tf.data.Dataset.from_tensor_slices(values).batch(4)
    self.assertItemsEqual(["labels", "weights"], dataset.output_shapes.keys())
    self.assertFalse(dataset.output_shapes["labels"].is_fully_defined())
    self.assertFalse(dataset.output_shapes["weights"].is_fully_defined())

    dataset_pad = dataset_ops.pad_dataset_to_batch_size(dataset, 4)
    self.assertItemsEqual(["labels", "weights"],
                          dataset_pad.output_shapes.keys())
    self.assertEqual([4], dataset_pad.output_shapes["labels"])
    self.assertEqual([4], dataset_pad.output_shapes["weights"])

    next_batch = dataset_pad.make_one_shot_iterator().get_next()
    next_labels = next_batch["labels"]
    next_weights = next_batch["weights"]

    with self.test_session() as sess:
      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([0, 1, 2, 3], labels)
      self.assertAllEqual([100, 101, 102, 103], weights)

      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([4, 5, 6, 7], labels)
      self.assertAllEqual([104, 105, 106, 107], weights)

      labels, weights = sess.run([next_labels, next_weights])
      self.assertAllEqual([8, 9, 0, 0], labels)
      self.assertAllEqual([108, 109, 0, 0], weights)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run([next_labels, next_weights])

  def testSetBatchSizeSingleTensor1d(self):
    dataset = tf.data.Dataset.range(4).batch(2)
    self.assertFalse(dataset.output_shapes.is_fully_defined())

    dataset = dataset_ops.set_batch_size(dataset, 2)
    self.assertEqual([2], dataset.output_shapes)

    next_batch = dataset.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      batch_value = sess.run(next_batch)
      self.assertAllEqual([0, 1], batch_value)
      batch_value = sess.run(next_batch)
      self.assertAllEqual([2, 3], batch_value)
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_batch)

  def testSetBatchSizeSingleTensor2d(self):
    values = np.arange(12, dtype=np.int32).reshape([4, 3])
    dataset = tf.data.Dataset.from_tensor_slices(values).batch(2)
    self.assertFalse(dataset.output_shapes.is_fully_defined())

    dataset = dataset_ops.set_batch_size(dataset, 2)
    self.assertEqual([2, 3], dataset.output_shapes)

    next_batch = dataset.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      batch_value = sess.run(next_batch)
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], batch_value)
      batch_value = sess.run(next_batch)
      self.assertAllEqual([[6, 7, 8], [9, 10, 11]], batch_value)
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_batch)

  def testSetBatchSizeNested(self):
    values = {
        "a": 100 + np.arange(4, dtype=np.int32),
        "nest": {
            "b": np.arange(12, dtype=np.int32).reshape([4, 3]),
            "c": np.arange(4, dtype=np.int32)
        }
    }
    dataset = tf.data.Dataset.from_tensor_slices(values).batch(2)
    self.assertItemsEqual(["a", "nest"], dataset.output_shapes.keys())
    self.assertItemsEqual(["b", "c"], dataset.output_shapes["nest"].keys())
    self.assertFalse(dataset.output_shapes["a"].is_fully_defined())
    self.assertFalse(dataset.output_shapes["nest"]["b"].is_fully_defined())
    self.assertFalse(dataset.output_shapes["nest"]["c"].is_fully_defined())

    dataset = dataset_ops.set_batch_size(dataset, 2)
    self.assertItemsEqual(["a", "nest"], dataset.output_shapes.keys())
    self.assertItemsEqual(["b", "c"], dataset.output_shapes["nest"].keys())
    self.assertEqual([2], dataset.output_shapes["a"])
    self.assertEqual([2, 3], dataset.output_shapes["nest"]["b"])
    self.assertEqual([2], dataset.output_shapes["nest"]["c"])

    next_batch = dataset.make_one_shot_iterator().get_next()
    next_a = next_batch["a"]
    next_b = next_batch["nest"]["b"]
    next_c = next_batch["nest"]["c"]
    with self.test_session() as sess:
      a, b, c = sess.run([next_a, next_b, next_c])
      self.assertAllEqual([100, 101], a)
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], b)
      self.assertAllEqual([0, 1], c)
      a, b, c = sess.run([next_a, next_b, next_c])
      self.assertAllEqual([102, 103], a)
      self.assertAllEqual([[6, 7, 8], [9, 10, 11]], b)
      self.assertAllEqual([2, 3], c)
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_batch)


class BuildDatasetTest(tf.test.TestCase):

  def setUp(self):
    super(BuildDatasetTest, self).setUp()

    # The test dataset contains 10 tensorflow.Example protocol buffers. The
    # i-th Example contains the following features:
    #   global_view = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    #   local_view = [0.0, 1.0, 2.0, 3.0]
    #   aux_feature = 100 + i
    #   label_str = "PC" if i % 3 == 0 else "AFP" if i % 3 == 1 else "NTP"
    self._file_pattern = os.path.join(FLAGS.test_srcdir, _TEST_TFRECORD_FILE)

    self._input_config = configdict.ConfigDict({
        "features": {
            "global_view": {
                "is_time_series": True,
                "length": 8
            },
            "local_view": {
                "is_time_series": True,
                "length": 4
            },
            "aux_feature": {
                "is_time_series": False,
                "length": 1
            }
        }
    })

  def testNonExistentFileRaisesValueError(self):
    with self.assertRaises(ValueError):
      dataset_ops.build_dataset(
          file_pattern="nonexistent",
          input_config=self._input_config,
          batch_size=4)

  def testBuildWithoutLabels(self):
    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4,
        include_labels=False)

    # We can use a one-shot iterator without labels because we don't have the
    # stateful hash map for label ids.
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()

    # Expect features only.
    self.assertItemsEqual(["time_series_features", "aux_features"],
                          features.keys())

    with self.test_session() as sess:
      # Batch 1.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
                                           f["aux_features"]["aux_feature"])

      # Batch 2.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
                                           f["aux_features"]["aux_feature"])

      # Batch 3.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[108], [109]],
                                           f["aux_features"]["aux_feature"])

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)

  def testLabels1(self):
    self._input_config["label_feature"] = "label_str"
    self._input_config["label_map"] = {"PC": 0, "AFP": 1, "NTP": 2}

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4)

    # We need an initializable iterator when using labels because of the
    # stateful label id hash table.
    iterator = dataset.make_initializable_iterator()
    inputs = iterator.get_next()
    init_op = tf.tables_initializer()

    # Expect features and labels.
    self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
                          inputs.keys())
    labels = inputs["labels"]

    with self.test_session() as sess:
      sess.run([init_op, iterator.initializer])

      # Fetch 3 batches.
      np.testing.assert_array_equal([0, 1, 2, 0], sess.run(labels))
      np.testing.assert_array_equal([1, 2, 0, 1], sess.run(labels))
      np.testing.assert_array_equal([2, 0], sess.run(labels))

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(labels)

  def testLabels2(self):
    self._input_config["label_feature"] = "label_str"
    self._input_config["label_map"] = {"PC": 1, "AFP": 0, "NTP": 0}

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4)

    # We need an initializable iterator when using labels because of the
    # stateful label id hash table.
    iterator = dataset.make_initializable_iterator()
    inputs = iterator.get_next()
    init_op = tf.tables_initializer()

    # Expect features and labels.
    self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
                          inputs.keys())
    labels = inputs["labels"]

    with self.test_session() as sess:
      sess.run([init_op, iterator.initializer])

      # Fetch 3 batches.
      np.testing.assert_array_equal([1, 0, 0, 1], sess.run(labels))
      np.testing.assert_array_equal([0, 0, 1, 0], sess.run(labels))
      np.testing.assert_array_equal([0, 1], sess.run(labels))

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(labels)

  def testBadLabelIdsRaisesValueError(self):
    self._input_config["label_feature"] = "label_str"

    # Label ids should be contiguous integers starting at 0.
    self._input_config["label_map"] = {"PC": 1, "AFP": 2, "NTP": 3}

    with self.assertRaises(ValueError):
      dataset_ops.build_dataset(
          file_pattern=self._file_pattern,
          input_config=self._input_config,
          batch_size=4)

  def testUnknownLabel(self):
    self._input_config["label_feature"] = "label_str"

    # label_map does not include "NTP".
    self._input_config["label_map"] = {"PC": 1, "AFP": 0}

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4)

    # We need an initializable iterator when using labels because of the
    # stateful label id hash table.
    iterator = dataset.make_initializable_iterator()
    inputs = iterator.get_next()
    init_op = tf.tables_initializer()

    # Expect features and labels.
    self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
                          inputs.keys())
    labels = inputs["labels"]

    with self.test_session() as sess:
      sess.run([init_op, iterator.initializer])

      # Unknown label "NTP".
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sess.run(labels)

  def testReverseTimeSeries(self):
    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4,
        reverse_time_series_prob=1,
        include_labels=False)

    # We can use a one-shot iterator without labels because we don't have the
    # stateful hash map for label ids.
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()

    # Expect features only.
    self.assertItemsEqual(["time_series_features", "aux_features"],
                          features.keys())

    with self.test_session() as sess:
      # Batch 1.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [3, 2, 1, 0],
          [3, 2, 1, 0],
          [3, 2, 1, 0],
          [3, 2, 1, 0],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
                                           f["aux_features"]["aux_feature"])

      # Batch 2.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [3, 2, 1, 0],
          [3, 2, 1, 0],
          [3, 2, 1, 0],
          [3, 2, 1, 0],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
                                           f["aux_features"]["aux_feature"])

      # Batch 3.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [7, 6, 5, 4, 3, 2, 1, 0],
          [7, 6, 5, 4, 3, 2, 1, 0],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [3, 2, 1, 0],
          [3, 2, 1, 0],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[108], [109]],
                                           f["aux_features"]["aux_feature"])

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)

  def testRepeat(self):
    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4,
        include_labels=False)

    # We can use a one-shot iterator without labels because we don't have the
    # stateful hash map for label ids.
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()

    # Expect features only.
    self.assertItemsEqual(["time_series_features", "aux_features"],
                          features.keys())

    with self.test_session() as sess:
      # Batch 1.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
                                           f["aux_features"]["aux_feature"])

      # Batch 2.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
                                           f["aux_features"]["aux_feature"])

      # Batch 3.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[108], [109]],
                                           f["aux_features"]["aux_feature"])

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)

  def testTPU(self):
    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=4,
        include_labels=False)

    # We can use a one-shot iterator without labels because we don't have the
    # stateful hash map for label ids.
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()

    # Expect features only.
    self.assertItemsEqual(["time_series_features", "aux_features"],
                          features.keys())

    with self.test_session() as sess:
      # Batch 1.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
                                           f["aux_features"]["aux_feature"])

      # Batch 2.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
                                           f["aux_features"]["aux_feature"])

      # Batch 3.
      f = sess.run(features)
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3, 4, 5, 6, 7],
          [0, 1, 2, 3, 4, 5, 6, 7],
      ], f["time_series_features"]["global_view"])
      np.testing.assert_array_almost_equal([
          [0, 1, 2, 3],
          [0, 1, 2, 3],
      ], f["time_series_features"]["local_view"])
      np.testing.assert_array_almost_equal([[108], [109]],
                                           f["aux_features"]["aux_feature"])

      # No more batches.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)


if __name__ == "__main__":
  tf.test.main()
```
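The module under test, dataset_ops.py, is not part of this diff. For orientation, here is a minimal sketch of what `pad_tensor_to_batch_size` must do to satisfy the assertions above; this is a hypothetical re-implementation inferred from the tests, not the original code:

```python
import tensorflow as tf


def pad_tensor_to_batch_size(tensor, batch_size):
  """Sketch: pads dimension 0 of `tensor` with zeros up to `batch_size`."""
  if tensor.shape.ndims == 0:
    # Matches the ValueError the test expects for a 0-dimensional input.
    raise ValueError("Cannot pad a 0-dimensional Tensor: %s" % tensor)
  num_pad = batch_size - tf.shape(tensor)[0]
  pad_shape = tf.concat([[num_pad], tf.shape(tensor)[1:]], axis=0)
  padded = tf.concat([tensor, tf.zeros(pad_shape, dtype=tensor.dtype)], axis=0)
  # tf.concat loses the static batch dimension; restore it so the padded
  # Tensor has the fully defined shape the tests assert.
  padded.set_shape([batch_size] + tensor.shape.as_list()[1:])
  return padded
```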
research/astronet/astronet/ops/input_ops.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Operations for feeding input data using TensorFlow placeholders."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def prepare_feed_dict(model, features, labels=None, is_training=None):
  """Prepares a feed_dict for sess.run() given a batch of features and labels.

  Args:
    model: An instance of AstroModel.
    features: Dictionary containing "time_series_features" and "aux_features".
      Each is a dictionary of named numpy arrays of shape
      [batch_size, length].
    labels: (Optional). Numpy array of shape [batch_size].
    is_training: (Optional). Python boolean to feed to the model.is_training
      Tensor (if None, no value is fed).

  Returns:
    feed_dict: A dictionary of input Tensor to numpy array.
  """
  feed_dict = {}
  for feature, tensor in model.time_series_features.items():
    feed_dict[tensor] = features["time_series_features"][feature]
  for feature, tensor in model.aux_features.items():
    feed_dict[tensor] = features["aux_features"][feature]
  if labels is not None:
    feed_dict[model.labels] = labels
  if is_training is not None:
    feed_dict[model.is_training] = is_training
  return feed_dict


def build_feature_placeholders(config):
  """Builds tf.Placeholder ops for feeding model features and labels.

  Args:
    config: ConfigDict containing the feature configurations.

  Returns:
    features: A dictionary containing "time_series_features" and
      "aux_features", each of which is a dictionary of tf.Placeholders of
      features from the input configuration. All features have dtype float32
      and shape [batch_size, length].
  """
  batch_size = None  # Batch size will be dynamically specified.
  features = {"time_series_features": {}, "aux_features": {}}
  for feature_name, feature_spec in config.items():
    placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[batch_size, feature_spec.length],
        name=feature_name)
    if feature_spec.is_time_series:
      features["time_series_features"][feature_name] = placeholder
    else:
      features["aux_features"][feature_name] = placeholder
  return features


def build_labels_placeholder():
  """Builds a tf.Placeholder op for feeding model labels.

  Returns:
    labels: An int64 tf.Placeholder with shape [batch_size].
  """
  batch_size = None  # Batch size will be dynamically specified.
  return tf.placeholder(dtype=tf.int64, shape=[batch_size], name="labels")
```
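A minimal usage sketch of the placeholder builders above, assuming only what the docstrings state (ConfigDict supports dict-style construction and attribute access, as used throughout this diff); the feature names are hypothetical:

```python
from astronet.ops import input_ops
from tf_util import configdict

# Hypothetical configuration with one time-series and one auxiliary feature.
config = configdict.ConfigDict({
    "global_view": {"length": 8, "is_time_series": True},
    "aux_feature": {"length": 1, "is_time_series": False},
})

features = input_ops.build_feature_placeholders(config)
labels = input_ops.build_labels_placeholder()

# Placeholders are grouped by feature type, with a dynamic batch dimension:
#   features["time_series_features"]["global_view"]  -> float32, [None, 8]
#   features["aux_features"]["aux_feature"]          -> float32, [None, 1]
#   labels                                           -> int64,   [None]
# prepare_feed_dict(model, features_np, labels_np, is_training=True) then
# maps these placeholders to numpy arrays for sess.run().
```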
research/astronet/astronet/ops/input_ops_test.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for input_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from astronet.ops import input_ops
from tf_util import configdict


class InputOpsTest(tf.test.TestCase):

  def assertFeatureShapesEqual(self, expected_shapes, features):
    """Asserts that a dict of feature placeholders has the expected shapes.

    Args:
      expected_shapes: Dictionary of expected Tensor shapes, as lists,
        corresponding to the structure of 'features'.
      features: Dictionary of feature placeholders of the format returned by
        input_ops.build_feature_placeholders().
    """
    actual_shapes = {}
    for feature_type in features:
      actual_shapes[feature_type] = {
          feature: tensor.shape.as_list()
          for feature, tensor in features[feature_type].items()
      }
    self.assertDictEqual(expected_shapes, actual_shapes)

  def testBuildFeaturePlaceholders(self):
    # One time series feature.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {}
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # Two time series features.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "time_feature_2": {
            "length": 5,
            "is_time_series": True,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
            "time_feature_2": [None, 5],
        },
        "aux_features": {}
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # One aux feature.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "aux_feature_1": {
            "length": 1,
            "is_time_series": False,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {
            "aux_feature_1": [None, 1]
        }
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # Two aux features.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "aux_feature_1": {
            "length": 1,
            "is_time_series": False,
        },
        "aux_feature_2": {
            "length": 6,
            "is_time_series": False,
        },
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {
            "aux_feature_1": [None, 1],
            "aux_feature_2": [None, 6]
        }
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)


if __name__ == "__main__":
  tf.test.main()
```
research/astronet/astronet/ops/metrics.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for computing evaluation metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def _metric_variable(name, shape, dtype):
  """Creates a Variable in LOCAL_VARIABLES and METRIC_VARIABLES collections."""
  return tf.get_variable(
      name,
      initializer=tf.zeros(shape, dtype),
      trainable=False,
      collections=[
          tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES
      ])


def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
  """Builds TensorFlow operations to compute model evaluation metrics.

  Args:
    labels: Tensor with shape [batch_size].
    predictions: Tensor with shape [batch_size, output_dim].
    weights: Tensor with shape [batch_size].
    batch_losses: Tensor with shape [batch_size].
    output_dim: Dimension of the model output.

  Returns:
    A dictionary {metric_name: (metric_value, update_op)}.
  """
  # Compute the predicted labels.
  assert len(predictions.shape) == 2
  binary_classification = output_dim == 1
  if binary_classification:
    assert predictions.shape[1] == 1
    predictions = tf.squeeze(predictions, axis=[1])
    predicted_labels = tf.to_int32(
        tf.greater(predictions, 0.5), name="predicted_labels")
  else:
    predicted_labels = tf.argmax(
        predictions, 1, name="predicted_labels", output_type=tf.int32)

  metrics = {}
  with tf.variable_scope("metrics"):
    # Total number of examples.
    num_examples = _metric_variable("num_examples", [], tf.float32)
    update_num_examples = tf.assign_add(num_examples, tf.reduce_sum(weights))
    metrics["num_examples"] = (num_examples.read_value(), update_num_examples)

    # Accuracy metrics.
    num_correct = _metric_variable("num_correct", [], tf.float32)
    is_correct = weights * tf.to_float(tf.equal(labels, predicted_labels))
    update_num_correct = tf.assign_add(num_correct, tf.reduce_sum(is_correct))
    metrics["accuracy/num_correct"] = (num_correct.read_value(),
                                       update_num_correct)
    accuracy = tf.div(num_correct, num_examples, name="accuracy")
    metrics["accuracy/accuracy"] = (accuracy, tf.no_op())

    # Weighted cross-entropy loss.
    metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
        batch_losses, weights=weights, name="cross_entropy_loss")

    def _count_condition(name, labels_value, predicted_value):
      """Creates a counter for given values of predictions and labels."""
      count = _metric_variable(name, [], tf.float32)
      is_equal = tf.to_float(
          tf.logical_and(
              tf.equal(labels, labels_value),
              tf.equal(predicted_labels, predicted_value)))
      update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
      return count.read_value(), update_op

    # Confusion matrix metrics.
    num_labels = 2 if binary_classification else output_dim
    for gold_label in range(num_labels):
      for pred_label in range(num_labels):
        metric_name = "confusion_matrix/label_{}_pred_{}".format(
            gold_label, pred_label)
        metrics[metric_name] = _count_condition(
            metric_name, labels_value=gold_label, predicted_value=pred_label)

    # Possibly create AUC metric for binary classification.
    if binary_classification:
      labels = tf.cast(labels, dtype=tf.bool)
      metrics["auc"] = tf.metrics.auc(
          labels, predictions, weights=weights, num_thresholds=1000)

  return metrics


def create_metric_fn(model):
  """Creates a tuple (metric_fn, metric_fn_inputs).

  This function is primarily used for creating a TPUEstimator.

  The result of calling metric_fn(**metric_fn_inputs) is a dictionary
  {metric_name: (metric_value, update_op)}.

  Args:
    model: Instance of AstroModel.

  Returns:
    A tuple (metric_fn, metric_fn_inputs).
  """
  weights = model.weights
  if weights is None:
    weights = tf.ones_like(model.labels, dtype=tf.float32)
  metric_fn_inputs = {
      "labels": model.labels,
      "predictions": model.predictions,
      "weights": weights,
      "batch_losses": model.batch_losses,
  }

  def metric_fn(labels, predictions, weights, batch_losses):
    return _build_metrics(
        labels,
        predictions,
        weights,
        batch_losses,
        output_dim=model.hparams.output_dim)

  return metric_fn, metric_fn_inputs


def create_metrics(model):
  """Creates a dictionary {metric_name: (metric_value, update_op)}.

  This function is primarily used for creating an Estimator.

  Args:
    model: Instance of AstroModel.

  Returns:
    A dictionary {metric_name: (metric_value, update_op)}.
  """
  metric_fn, metric_fn_inputs = create_metric_fn(model)
  return metric_fn(**metric_fn_inputs)
```
research/astronet/astronet/ops/metrics_test.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for metrics.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from astronet.ops import metrics


def _unpack_metric_map(names_to_tuples):
  """Unpacks {metric_name: (metric_value, update_op)} into separate dicts."""
  metric_names = names_to_tuples.keys()
  value_ops, update_ops = zip(*names_to_tuples.values())
  return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))


class _MockHparams(object):
  """Mock Hparams class to support accessing with dot notation."""
  pass


class _MockModel(object):
  """Mock model for testing."""

  def __init__(self, labels, predictions, weights, batch_losses, output_dim):
    self.labels = tf.constant(labels, dtype=tf.int32)
    self.predictions = tf.constant(predictions, dtype=tf.float32)
    self.weights = None if weights is None else tf.constant(
        weights, dtype=tf.float32)
    self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
    self.hparams = _MockHparams()
    self.hparams.output_dim = output_dim


class MetricsTest(tf.test.TestCase):

  def testMultiClassificationWithoutWeights(self):
    labels = [0, 1, 2, 3]
    predictions = [
        [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
        [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
    ]
    weights = None
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses,
                       output_dim=4)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "confusion_matrix/label_0_pred_0": 1,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 1,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 1,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 8,
          "accuracy/num_correct": 4,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "confusion_matrix/label_0_pred_0": 2,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 2,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 2,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

  def testMultiClassificationWithWeights(self):
    labels = [0, 1, 2, 3]
    predictions = [
        [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
        [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
    ]
    weights = [0, 1, 0, 1]
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses,
                       output_dim=4)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 2,
          "accuracy/num_correct": 1,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 0,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 1,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 0,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 2,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

  def testBinaryClassificationWithoutWeights(self):
    labels = [0, 1, 1, 0]
    predictions = [
        [0.4],  # Predicted label = 0
        [0.6],  # Predicted label = 1
        [0.0],  # Predicted label = 0
        [1.0],  # Predicted label = 1
    ]
    weights = None
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses,
                       output_dim=1)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "auc": 0.25,
          "confusion_matrix/label_0_pred_0": 1,
          "confusion_matrix/label_0_pred_1": 1,
          "confusion_matrix/label_1_pred_0": 1,
          "confusion_matrix/label_1_pred_1": 1,
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 8,
          "accuracy/num_correct": 4,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "auc": 0.25,
          "confusion_matrix/label_0_pred_0": 2,
          "confusion_matrix/label_0_pred_1": 2,
          "confusion_matrix/label_1_pred_0": 2,
          "confusion_matrix/label_1_pred_1": 2,
      }, sess.run(value_ops))

  def testBinaryClassificationWithWeights(self):
    labels = [0, 1, 1, 0]
    predictions = [
        [0.4],  # Predicted label = 0
        [0.6],  # Predicted label = 1
        [0.0],  # Predicted label = 0
        [1.0],  # Predicted label = 1
    ]
    weights = [0, 1, 0, 1]
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses,
                       output_dim=1)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 2,
          "accuracy/num_correct": 1,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "auc": 0,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 1,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "auc": 0,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 2,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
      }, sess.run(value_ops))


if __name__ == "__main__":
  tf.test.main()
```
research/astronet/astronet/ops/test_data/test_dataset.tfrecord  deleted 100644 → 0

File deleted.
research/astronet/astronet/ops/testing.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorFlow utilities for unit tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def get_variable_by_name(name, scope=""):
  """Gets a tf.Variable by name.

  Args:
    name: Name of the Variable within the specified scope.
    scope: Variable scope; use the empty string for top-level scope.

  Returns:
    The matching tf.Variable object.
  """
  with tf.variable_scope(scope, reuse=True):
    return tf.get_variable(name)


def fake_features(feature_spec, batch_size):
  """Creates random numpy arrays representing input features for unit testing.

  Args:
    feature_spec: Dictionary containing the feature specifications.
    batch_size: Integer batch size.

  Returns:
    Dictionary containing "time_series_features" and "aux_features". Each is a
    dictionary of named numpy arrays of shape [batch_size, length].
  """
  features = {"time_series_features": {}, "aux_features": {}}
  for name, spec in feature_spec.items():
    ftype = ("time_series_features"
             if spec["is_time_series"] else "aux_features")
    features[ftype][name] = np.random.random([batch_size, spec["length"]])
  return features


def fake_labels(output_dim, batch_size):
  """Creates a random numpy array representing labels for unit testing.

  Args:
    output_dim: Number of output units in the classification model.
    batch_size: Integer batch size.

  Returns:
    Numpy array of shape [batch_size].
  """
  # Binary classification is denoted by output_dim == 1. In that case there
  # are 2 label classes even though there is only 1 output prediction by the
  # model. Otherwise, the classification task is multi-labeled with
  # output_dim classes.
  num_labels = 2 if output_dim == 1 else output_dim
  return np.random.randint(num_labels, size=batch_size)
```
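A short sketch of how these helpers are used; the feature spec shown here is a hypothetical example in the dictionary format `fake_features` expects:

```python
from astronet.ops import testing

# Hypothetical feature spec: one time-series and one auxiliary feature.
feature_spec = {
    "global_view": {"length": 8, "is_time_series": True},
    "aux_feature": {"length": 1, "is_time_series": False},
}

features = testing.fake_features(feature_spec, batch_size=4)
labels = testing.fake_labels(output_dim=1, batch_size=4)

# features["time_series_features"]["global_view"].shape == (4, 8)
# features["aux_features"]["aux_feature"].shape == (4, 1)
# labels.shape == (4,), with values in {0, 1} since output_dim == 1 denotes
# binary classification.
```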
research/astronet/astronet/ops/training.py  deleted 100644 → 0
```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for training an AstroNet model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def create_learning_rate(hparams, global_step):
  """Creates a learning rate Tensor.

  Args:
    hparams: ConfigDict containing the learning rate configuration.
    global_step: The global step Tensor.

  Returns:
    A learning rate Tensor.
  """
  if hparams.get("learning_rate_decay_factor"):
    learning_rate = tf.train.exponential_decay(
        learning_rate=float(hparams.learning_rate),
        global_step=global_step,
        decay_steps=hparams.learning_rate_decay_steps,
        decay_rate=hparams.learning_rate_decay_factor,
        staircase=hparams.learning_rate_decay_staircase)
  else:
    learning_rate = tf.constant(hparams.learning_rate)
  return learning_rate


def create_optimizer(hparams, learning_rate, use_tpu=False):
  """Creates a TensorFlow Optimizer.

  Args:
    hparams: ConfigDict containing the optimizer configuration.
    learning_rate: A Python float or a scalar Tensor.
    use_tpu: If True, the returned optimizer is wrapped in a
      CrossShardOptimizer.

  Returns:
    A TensorFlow optimizer.

  Raises:
    ValueError: If hparams.optimizer is unrecognized.
  """
  optimizer_name = hparams.optimizer.lower()
  if optimizer_name == "momentum":
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=hparams.get("momentum", 0.9),
        use_nesterov=hparams.get("use_nesterov", False))
  elif optimizer_name == "sgd":
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  elif optimizer_name == "adagrad":
    optimizer = tf.train.AdagradOptimizer(learning_rate)
  elif optimizer_name == "adam":
    optimizer = tf.train.AdamOptimizer(learning_rate)
  elif optimizer_name == "rmsprop":
    optimizer = tf.train.RMSPropOptimizer(learning_rate)
  else:
    raise ValueError("Unknown optimizer: {}".format(hparams.optimizer))

  if use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  return optimizer


def create_train_op(model, optimizer):
  """Creates a Tensor to train the model.

  Args:
    model: Instance of AstroModel.
    optimizer: Instance of tf.train.Optimizer.

  Returns:
    A Tensor that runs a single training step and returns model.total_loss.
  """
  # Maybe clip gradient norms.
  transform_grads_fn = None
  if model.hparams.get("clip_gradient_norm"):
    transform_grads_fn = tf.contrib.training.clip_gradient_norms_fn(
        model.hparams.clip_gradient_norm)

  # Create train op.
  return tf.contrib.training.create_train_op(
      total_loss=model.total_loss,
      optimizer=optimizer,
      global_step=model.global_step,
      transform_grads_fn=transform_grads_fn)
```
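A minimal sketch of how these three functions chain together. The hyperparameter values are hypothetical, only the fields read by the code above are set, and it assumes ConfigDict's `get` returns a falsy value for missing keys (as the code above itself relies on):

```python
import tensorflow as tf

from astronet.ops import training
from tf_util import configdict

# Hypothetical hyperparameters: constant learning rate, Adam, no decay and no
# gradient clipping (so create_learning_rate takes its tf.constant branch).
hparams = configdict.ConfigDict({
    "learning_rate": 1e-3,
    "optimizer": "adam",
})

global_step = tf.train.get_or_create_global_step()
learning_rate = training.create_learning_rate(hparams, global_step)
optimizer = training.create_optimizer(hparams, learning_rate)
# Given an AstroModel instance, a single fused train step would then be:
# train_op = training.create_train_op(model, optimizer)
```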
research/astronet/astronet/predict.py  deleted 100644 → 0
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate predictions for a Threshold Crossing Event using a trained model."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
sys
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
tensorflow
as
tf
from
astronet
import
models
from
astronet.data
import
preprocess
from
astronet.util
import
estimator_util
from
tf_util
import
config_util
from
tf_util
import
configdict
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
type
=
str
,
required
=
True
,
help
=
"Name of the model class."
)
parser
.
add_argument
(
"--config_name"
,
type
=
str
,
help
=
"Name of the model and training configuration. Exactly one of "
"--config_name or --config_json is required."
)
parser
.
add_argument
(
"--config_json"
,
type
=
str
,
help
=
"JSON string or JSON file containing the model and training "
"configuration. Exactly one of --config_name or --config_json is required."
)
parser
.
add_argument
(
"--model_dir"
,
type
=
str
,
required
=
True
,
help
=
"Directory containing a model checkpoint."
)
parser
.
add_argument
(
"--kepler_data_dir"
,
type
=
str
,
required
=
True
,
help
=
"Base folder containing Kepler data."
)
parser
.
add_argument
(
"--kepler_id"
,
type
=
int
,
required
=
True
,
help
=
"Kepler ID of the target star."
)
parser
.
add_argument
(
"--period"
,
type
=
float
,
required
=
True
,
help
=
"Period of the TCE, in days."
)
parser
.
add_argument
(
"--t0"
,
type
=
float
,
required
=
True
,
help
=
"Epoch of the TCE."
)
parser
.
add_argument
(
"--duration"
,
type
=
float
,
required
=
True
,
help
=
"Duration of the TCE, in days."
)
parser
.
add_argument
(
"--output_image_file"
,
type
=
str
,
help
=
"If specified, path to an output image file containing feature plots. "
"Must end in a valid image extension, e.g. png."
)
def
_process_tce
(
feature_config
):
"""Reads and process the input features of a Threshold Crossing Event.
Args:
feature_config: ConfigDict containing the feature configurations.
Returns:
A dictionary of processed light curve features.
Raises:
ValueError: If feature_config contains features other than 'global_view'
and 'local_view'.
"""
if
not
{
"global_view"
,
"local_view"
}.
issuperset
(
feature_config
.
keys
()):
raise
ValueError
(
"Only 'global_view' and 'local_view' features are supported."
)
# Read and process the light curve.
all_time
,
all_flux
=
preprocess
.
read_light_curve
(
FLAGS
.
kepler_id
,
FLAGS
.
kepler_data_dir
)
time
,
flux
=
preprocess
.
process_light_curve
(
all_time
,
all_flux
)
time
,
flux
=
preprocess
.
phase_fold_and_sort_light_curve
(
time
,
flux
,
FLAGS
.
period
,
FLAGS
.
t0
)
# Generate the local and global views.
features
=
{}
if
"global_view"
in
feature_config
:
global_view
=
preprocess
.
global_view
(
time
,
flux
,
FLAGS
.
period
)
# Add a batch dimension.
features
[
"global_view"
]
=
np
.
expand_dims
(
global_view
,
0
)
if
"local_view"
in
feature_config
:
local_view
=
preprocess
.
local_view
(
time
,
flux
,
FLAGS
.
period
,
FLAGS
.
duration
)
# Add a batch dimension.
features
[
"local_view"
]
=
np
.
expand_dims
(
local_view
,
0
)
# Possibly save plots.
if
FLAGS
.
output_image_file
:
ncols
=
len
(
features
)
fig
,
axes
=
plt
.
subplots
(
1
,
ncols
,
figsize
=
(
10
*
ncols
,
5
),
squeeze
=
False
)
for
i
,
name
in
enumerate
(
sorted
(
features
)):
ax
=
axes
[
0
][
i
]
ax
.
plot
(
features
[
name
][
0
],
"."
)
ax
.
set_title
(
name
)
ax
.
set_xlabel
(
"Bucketized Time (days)"
)
ax
.
set_ylabel
(
"Normalized Flux"
)
fig
.
tight_layout
()
fig
.
savefig
(
FLAGS
.
output_image_file
,
bbox_inches
=
"tight"
)
return
features
def
main
(
_
):
model_class
=
models
.
get_model_class
(
FLAGS
.
model
)
# Look up the model configuration.
assert
(
FLAGS
.
config_name
is
None
)
!=
(
FLAGS
.
config_json
is
None
),
(
"Exactly one of --config_name or --config_json is required."
)
config
=
(
models
.
get_model_config
(
FLAGS
.
model
,
FLAGS
.
config_name
)
if
FLAGS
.
config_name
else
config_util
.
parse_json
(
FLAGS
.
config_json
))
config
=
configdict
.
ConfigDict
(
config
)
# Create the estimator.
estimator
=
estimator_util
.
create_estimator
(
model_class
,
config
.
hparams
,
model_dir
=
FLAGS
.
model_dir
)
# Read and process the input features.
features
=
_process_tce
(
config
.
inputs
.
features
)
# Create an input function.
def
input_fn
():
return
tf
.
data
.
Dataset
.
from_tensors
({
"time_series_features"
:
features
})
# Generate the predictions.
for
predictions
in
estimator
.
predict
(
input_fn
):
assert
len
(
predictions
)
==
1
print
(
"Prediction:"
,
predictions
[
0
])
if
__name__
==
"__main__"
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
FLAGS
,
unparsed
=
parser
.
parse_known_args
()
tf
.
app
.
run
(
main
=
main
,
argv
=
[
sys
.
argv
[
0
]]
+
unparsed
)
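The one-shot prediction pattern in `main()` above wraps a single in-memory batch in a `tf.data.Dataset`, so `Estimator.predict` yields exactly one result per example. A minimal sketch of the same pattern (TF 1.x; the feature name and shape are illustrative, and the estimator line stands in for `estimator_util.create_estimator(...)`):

```python
import numpy as np
import tensorflow as tf

# Hypothetical feature dict with a leading batch dimension of 1, standing in
# for the output of _process_tce().
features = {"global_view": np.random.rand(1, 2001).astype(np.float32)}

def input_fn():
  # A dataset with exactly one element: the whole feature dict.
  return tf.data.Dataset.from_tensors({"time_series_features": features})

# estimator = estimator_util.create_estimator(model_class, config.hparams,
#                                             model_dir="/tmp/astronet")
# for predictions in estimator.predict(input_fn):
#   print("Prediction:", predictions[0])
```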
research/astronet/astronet/train.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training an AstroNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

from astronet import models
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner

parser = argparse.ArgumentParser()

parser.add_argument(
    "--model", type=str, required=True, help="Name of the model class.")

parser.add_argument(
    "--config_name",
    type=str,
    help="Name of the model and training configuration. Exactly one of "
    "--config_name or --config_json is required.")

parser.add_argument(
    "--config_json",
    type=str,
    help="JSON string or JSON file containing the model and training "
    "configuration. Exactly one of --config_name or --config_json is required.")

parser.add_argument(
    "--train_files",
    type=str,
    required=True,
    help="Comma-separated list of file patterns matching the TFRecord files in "
    "the training dataset.")

parser.add_argument(
    "--eval_files",
    type=str,
    help="Comma-separated list of file patterns matching the TFRecord files in "
    "the validation dataset.")

parser.add_argument(
    "--model_dir",
    type=str,
    required=True,
    help="Directory for model checkpoints and summaries.")

parser.add_argument(
    "--train_steps",
    type=int,
    default=625,
    help="Total number of steps to train the model for.")

parser.add_argument(
    "--shuffle_buffer_size",
    type=int,
    default=15000,
    help="Size of the shuffle buffer for the training dataset.")


def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(model_class, config.hparams,
                                              run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate
  # through the dataset one epoch at a time if we are alternating with
  # evaluation, otherwise we iterate indefinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }
    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything here.
      pass


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
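The `repeat=1 if FLAGS.eval_files else None` line above carries the epoch bookkeeping. A sketch of the underlying `tf.data` repeat semantics, with a toy dataset standing in for the TFRecord pipeline:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)

# repeat(1): a single pass, after which iteration ends -- this is what lets
# continuous_train_and_eval() run evaluation after each training epoch.
one_epoch = ds.repeat(1)

# repeat(None): cycles indefinitely -- used when there is no eval set and
# training runs uninterrupted until max_steps.
endless = ds.repeat(None)
```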
research/astronet/astronet/util/BUILD
deleted
100644 → 0
View file @
17c2f0cc
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//astronet/ops:dataset_ops",
        "//astronet/ops:metrics",
        "//astronet/ops:training",
    ],
)
research/astronet/astronet/util/__init__.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
research/astronet/astronet/util/estimator_util.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import tensorflow as tf

from astronet.ops import dataset_ops
from astronet.ops import metrics
from astronet.ops import training


class _InputFn(object):
  """Class that acts as a callable input function for Estimator train / eval."""

  def __init__(self,
               file_pattern,
               input_config,
               mode,
               shuffle_values_buffer=0,
               repeat=1):
    """Initializes the input function.

    Args:
      file_pattern: File pattern matching input TFRecord files, e.g.
        "/tmp/train-?????-of-00100". May also be a comma-separated list of file
        patterns.
      input_config: ConfigDict containing feature and label specifications.
      mode: A tf.estimator.ModeKeys.
      shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
        size.
      repeat: The number of times to repeat the dataset. If None or -1 the
        elements will be repeated indefinitely.
    """
    self._file_pattern = file_pattern
    self._input_config = input_config
    self._mode = mode
    self._shuffle_values_buffer = shuffle_values_buffer
    self._repeat = repeat

  def __call__(self, config, params):
    """Builds the input pipeline."""
    # Infer whether this input_fn was called by Estimator or TPUEstimator using
    # the config type.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    mode = self._mode
    include_labels = (
        mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
    reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
    shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=self._shuffle_values_buffer,
        repeat=self._repeat,
        use_tpu=use_tpu)

    return dataset


def create_input_fn(file_pattern,
                    input_config,
                    mode,
                    shuffle_values_buffer=0,
                    repeat=1):
  """Creates an input_fn that reads a dataset from sharded TFRecord files.

  Args:
    file_pattern: File pattern matching input TFRecord files, e.g.
      "/tmp/train-?????-of-00100". May also be a comma-separated list of file
      patterns.
    input_config: ConfigDict containing feature and label specifications.
    mode: A tf.estimator.ModeKeys.
    shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
    repeat: The number of times to repeat the dataset. If None or -1 the
      elements will be repeated indefinitely.

  Returns:
    A callable that builds the input pipeline and returns a tf.data.Dataset
    object.
  """
  return _InputFn(file_pattern, input_config, mode, shuffle_values_buffer,
                  repeat)


class _ModelFn(object):
  """Class that acts as a callable model function for Estimator train / eval."""

  def __init__(self, model_class, hparams, use_tpu=False):
    """Initializes the model function.

    Args:
      model_class: Model class.
      hparams: ConfigDict containing hyperparameters for building and training
        the model.
      use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
        will be returned.
    """
    self._model_class = model_class
    self._base_hparams = hparams
    self._use_tpu = use_tpu

  def __call__(self, features, labels, mode, params):
    """Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
    hparams = copy.deepcopy(self._base_hparams)
    if "batch_size" in params:
      hparams.batch_size = params["batch_size"]

    # Allow labels to be passed in the features dictionary.
    if "labels" in features:
      if labels is not None and labels is not features["labels"]:
        raise ValueError(
            "Conflicting labels: features['labels'] = {}, labels = {}".format(
                features["labels"], labels))
      labels = features.pop("labels")

    model = self._model_class(features, labels, hparams, mode)
    model.build()

    # Possibly create train_op.
    use_tpu = self._use_tpu
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      learning_rate = training.create_learning_rate(hparams, model.global_step)
      optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
      train_op = training.create_train_op(model, optimizer)

    # Possibly create evaluation metrics.
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
      eval_metrics = (
          metrics.create_metric_fn(model)
          if use_tpu else metrics.create_metrics(model))

    if use_tpu:
      estimator = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metrics=eval_metrics)
    else:
      estimator = tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metrics)

    return estimator


def create_model_fn(model_class, hparams, use_tpu=False):
  """Wraps model_class as an Estimator or TPUEstimator model_fn.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
      Estimator model_fn is returned.

  Returns:
    model_fn: A callable that constructs the model and returns a
      TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
  """
  return _ModelFn(model_class, hparams, use_tpu)


def create_estimator(model_class,
                     hparams,
                     run_config=None,
                     model_dir=None,
                     eval_batch_size=None):
  """Wraps model_class as an Estimator or TPUEstimator.

  If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
  If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
    model_dir: Optional directory for saving the model. If not passed
      explicitly, it must be specified in run_config.
    eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
      if run_config is a tf.contrib.tpu.RunConfig. Defaults to
      hparams.batch_size.

  Returns:
    An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
    TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.

  Raises:
    ValueError:
      If model_dir is not passed explicitly or in run_config.model_dir, or if
      eval_batch_size is specified and run_config is not a
      tf.contrib.tpu.RunConfig.
  """
  if run_config is None:
    run_config = tf.estimator.RunConfig()
  else:
    run_config = copy.deepcopy(run_config)

  if not model_dir and not run_config.model_dir:
    raise ValueError(
        "model_dir must be passed explicitly or specified in run_config")

  use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
  model_fn = create_model_fn(model_class, hparams, use_tpu)

  if use_tpu:
    eval_batch_size = eval_batch_size or hparams.batch_size
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=eval_batch_size)
  else:
    if eval_batch_size is not None:
      raise ValueError("eval_batch_size can only be specified for TPU.")
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        params={"batch_size": hparams.batch_size})

  return estimator
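A hedged end-to-end sketch of how these helpers compose; the model name, config name, and paths below are placeholders, and real values would come from the model registry in `astronet.models`:

```python
import tensorflow as tf
from astronet import models
from astronet.util import estimator_util
from tf_util import configdict

# Hypothetical model/config names; substitute ones registered in models.py.
model_class = models.get_model_class("AstroCNNModel")
config = configdict.ConfigDict(
    models.get_model_config("AstroCNNModel", "local_global"))

estimator = estimator_util.create_estimator(
    model_class, config.hparams, model_dir="/tmp/astronet")

train_input_fn = estimator_util.create_input_fn(
    file_pattern="/tmp/astronet/tfrecords/train-*",
    input_config=config.inputs,
    mode=tf.estimator.ModeKeys.TRAIN,
    shuffle_values_buffer=15000,
    repeat=None)

estimator.train(train_input_fn, max_steps=625)
```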
research/astronet/astrowavenet/BUILD
deleted
100644 → 0
View file @
17c2f0cc
"""A TensorFlow model for generative modeling of light curves."""
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_binary(
    name = "trainer",
    srcs = ["trainer.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":astrowavenet_model",
        ":configurations",
        "//astrowavenet/data:kepler_light_curves",
        "//astrowavenet/data:synthetic_transits",
        "//astrowavenet/util:estimator_util",
        "//tf_util:config_util",
        "//tf_util:configdict",
        "//tf_util:estimator_runner",
    ],
)

py_library(
    name = "configurations",
    srcs = ["configurations.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "astrowavenet_model",
    srcs = ["astrowavenet_model.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "astrowavenet_model_test",
    size = "small",
    srcs = ["astrowavenet_model_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":astrowavenet_model",
        ":configurations",
        "//tf_util:configdict",
    ],
)
research/astronet/astrowavenet/README.md
deleted
100644 → 0
View file @
17c2f0cc
# AstroWaveNet: A generative model for light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
## Code Authors
Alex Tamkin: [@atamkin](https://github.com/atamkin)

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Pull Requests / Issues

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Additional Dependencies

In addition to the [required packages](../README.md#required-packages) listed in
the top-level README, this package requires:

* **TensorFlow 1.12 or greater** ([instructions](https://www.tensorflow.org/install/))
* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))

## Basic Usage

To train a model on synthetic transits:

```bash
bazel build astrowavenet/...
```

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=synthetic_transits \
  --model_dir=/tmp/astrowavenet/ \
  --config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```
research/astronet/astrowavenet/__init__.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
research/astronet/astrowavenet/astrowavenet_model.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A TensorFlow WaveNet model for generative modeling of light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_probability as tfp


def _shift_right(x):
  """Shifts the input Tensor right by one index along the second dimension.

  Pads the front with zeros and discards the last element.

  Args:
    x: Input three-dimensional tf.Tensor.

  Returns:
    Padded, shifted tensor of same shape as input.
  """
  x_padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
  return x_padded[:, :-1, :]


class AstroWaveNet(object):
  """A TensorFlow model for generative modeling of light curves."""

  def __init__(self, features, hparams, mode):
    """Basic setup.

    The actual TensorFlow graph is constructed in build().

    Args:
      features: A dictionary containing "autoregressive_input" and
        "conditioning_stack", each of which is a named input Tensor. All
        features have dtype float32 and shape [batch_size, length, dim].
      hparams: A ConfigDict of hyperparameters for building the model.
      mode: A tf.estimator.ModeKeys to specify whether the graph should be
        built for training, evaluation or prediction.

    Raises:
      ValueError: If mode is invalid.
    """
    valid_modes = [
        tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
        tf.estimator.ModeKeys.PREDICT
    ]
    if mode not in valid_modes:
      raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))

    self.hparams = hparams
    self.mode = mode

    self.autoregressive_input = features["autoregressive_input"]
    self.conditioning_stack = features["conditioning_stack"]
    self.weights = features.get("weights")

    self.network_output = None  # Sum of skip connections from dilation stack.
    self.dist_params = None  # Dict of predicted distribution parameters.
    self.predicted_distributions = None  # Predicted distribution for examples.
    self.autoregressive_target = None  # Autoregressive target predictions.
    self.batch_losses = None  # Loss for each predicted distribution in batch.
    self.per_example_loss = None  # Loss for each example in batch.
    self.num_nonzero_weight_examples = None  # Examples with nonzero weight.
    self.total_loss = None  # Overall loss for the batch.
    self.global_step = None  # Global step Tensor.

  def causal_conv_layer(self, x, output_size, kernel_width, dilation_rate=1):
    """Applies a dilated causal convolution to the input.

    Args:
      x: tf.Tensor; Input tensor.
      output_size: int; Number of output filters for the convolution.
      kernel_width: int; Width of the 1D convolution window.
      dilation_rate: int; Dilation rate of the layer.

    Returns:
      Resulting tf.Tensor after applying the convolution.
    """
    causal_conv_op = tf.keras.layers.Conv1D(
        output_size,
        kernel_width,
        padding="causal",
        dilation_rate=dilation_rate,
        name="causal_conv")
    return causal_conv_op(x)

  def conv_1x1_layer(self, x, output_size, activation=None):
    """Applies a 1x1 convolution to the input.

    Args:
      x: tf.Tensor; Input tensor.
      output_size: int; Number of output filters for the 1x1 convolution.
      activation: Activation function to apply (e.g. 'relu').

    Returns:
      Resulting tf.Tensor after applying the 1x1 convolution.
    """
    conv_1x1_op = tf.keras.layers.Conv1D(
        output_size, 1, activation=activation, name="conv1x1")
    return conv_1x1_op(x)

  def gated_residual_layer(self, x, dilation_rate):
    """Creates a gated, dilated convolutional layer with a residual connection.

    Args:
      x: tf.Tensor; Input tensor.
      dilation_rate: int; Dilation rate of the layer.

    Returns:
      skip_connection: tf.Tensor; Skip connection to network_output layer.
      residual_connection: tf.Tensor; Sum of learned residual and input tensor.
    """
    with tf.variable_scope("filter"):
      x_filter_conv = self.causal_conv_layer(
          x, x.shape[-1].value, self.hparams.dilation_kernel_width,
          dilation_rate)
      cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
                                             x.shape[-1].value)
    with tf.variable_scope("gate"):
      x_gate_conv = self.causal_conv_layer(
          x, x.shape[-1].value, self.hparams.dilation_kernel_width,
          dilation_rate)
      cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
                                           x.shape[-1].value)

    gated_activation = (
        tf.tanh(x_filter_conv + cond_filter_conv) *
        tf.sigmoid(x_gate_conv + cond_gate_conv))

    with tf.variable_scope("residual"):
      residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
    with tf.variable_scope("skip"):
      skip_connection = self.conv_1x1_layer(gated_activation,
                                            self.hparams.skip_output_dim)

    return skip_connection, x + residual

  def build_network(self):
    """Builds WaveNet network.

    This consists of:
      1) An initial causal convolution,
      2) The dilation stack, and
      3) Summing of skip connections.
    The network output can then be used to predict various output
    distributions.

    Inputs:
      self.autoregressive_input
      self.conditioning_stack

    Outputs:
      self.network_output; tf.Tensor
    """
    skip_connections = []

    x = _shift_right(self.autoregressive_input)
    with tf.variable_scope("preprocess"):
      x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
                                 self.hparams.preprocess_kernel_width)

    for i in range(self.hparams.num_residual_blocks):
      with tf.variable_scope("block_{}".format(i)):
        for dilation_rate in self.hparams.dilation_rates:
          with tf.variable_scope("dilation_{}".format(dilation_rate)):
            skip_connection, x = self.gated_residual_layer(x, dilation_rate)
            skip_connections.append(skip_connection)

    self.network_output = tf.add_n(skip_connections)

  def dist_params_layer(self, x, outputs_size):
    """Converts x to the correct shape for populating a distribution object.

    Args:
      x: A Tensor of shape [batch_size, time_series_length, num_features].
      outputs_size: The number of parameters needed to specify all the
        distributions in the output. E.g. 5*3=15 to specify 5 distributions
        with 3 parameters each.

    Returns:
      The parameters of each distribution, a tensor of shape [batch_size,
      time_series_length, outputs_size].
    """
    with tf.variable_scope("dist_params"):
      conv_outputs = self.conv_1x1_layer(x, outputs_size)
    return conv_outputs

  def build_predictions(self):
    """Predicts output distribution from network outputs.

    Runs the model through:
      1) ReLU
      2) 1x1 convolution
      3) ReLU
      4) 1x1 convolution
    The result of the last convolution is used as the parameters of the
    specified output distribution (currently either Categorical or Normal).

    Inputs:
      self.network_output

    Outputs:
      self.dist_params
      self.predicted_distributions

    Raises:
      ValueError: If distribution type is neither 'categorical' nor 'normal'.
    """
    with tf.variable_scope("postprocess"):
      network_output = tf.keras.activations.relu(self.network_output)
      network_output = self.conv_1x1_layer(
          network_output,
          output_size=network_output.shape[-1].value,
          activation="relu")

    num_dists = self.autoregressive_input.shape[-1].value
    if self.hparams.output_distribution.type == "categorical":
      num_classes = self.hparams.output_distribution.num_classes
      logits = self.dist_params_layer(network_output, num_dists * num_classes)
      logits_shape = tf.concat(
          [tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
      logits = tf.reshape(logits, logits_shape)
      dist = tfp.distributions.Categorical(logits=logits)
      dist_params = {"logits": logits}
    elif self.hparams.output_distribution.type == "normal":
      loc_scale = self.dist_params_layer(network_output, num_dists * 2)
      loc, scale = tf.split(loc_scale, 2, axis=-1)
      # Ensure scale is positive.
      scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
      dist = tfp.distributions.Normal(loc, scale)
      dist_params = {"loc": loc, "scale": scale}
    else:
      raise ValueError("Unsupported distribution type {}".format(
          self.hparams.output_distribution.type))

    self.dist_params = dist_params
    self.predicted_distributions = dist

  def build_losses(self):
    """Builds the training losses.

    Inputs:
      self.predicted_distributions

    Outputs:
      self.batch_losses
      self.total_loss
    """
    autoregressive_target = self.autoregressive_input

    # Quantize the target if the output distribution is categorical.
    if self.hparams.output_distribution.type == "categorical":
      min_val = self.hparams.output_distribution.min_quantization_value
      max_val = self.hparams.output_distribution.max_quantization_value
      num_classes = self.hparams.output_distribution.num_classes
      clipped_target = tf.keras.backend.clip(autoregressive_target, min_val,
                                             max_val)
      quantized_target = tf.floor(
          (clipped_target - min_val) / (max_val - min_val) * num_classes)
      # Deal with the corner case where clipped_target equals max_val by
      # mapping the label num_classes to num_classes - 1. Essentially, this
      # makes the final quantized bucket a closed interval while all the other
      # quantized buckets are half-open intervals.
      quantized_target = tf.where(
          quantized_target >= num_classes,
          tf.ones_like(quantized_target) * (num_classes - 1), quantized_target)
      autoregressive_target = quantized_target

    log_prob = self.predicted_distributions.log_prob(autoregressive_target)

    weights = self.weights
    if weights is None:
      weights = tf.ones_like(log_prob)
    weights_dim = len(weights.shape)
    per_example_weight = tf.reduce_sum(
        weights, axis=list(range(1, weights_dim)))
    per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
    num_examples = tf.reduce_sum(per_example_indicator)

    batch_losses = -log_prob * weights
    losses_ndims = batch_losses.shape.ndims
    per_example_loss_sum = tf.reduce_sum(
        batch_losses, axis=list(range(1, losses_ndims)))
    per_example_loss = tf.where(per_example_weight > 0,
                                per_example_loss_sum / per_example_weight,
                                tf.zeros_like(per_example_weight))
    total_loss = tf.reduce_sum(per_example_loss) / num_examples

    self.autoregressive_target = autoregressive_target
    self.batch_losses = batch_losses
    self.per_example_loss = per_example_loss
    self.num_nonzero_weight_examples = num_examples
    self.total_loss = total_loss

  def build(self):
    """Creates all ops for training, evaluation or inference."""
    self.global_step = tf.train.get_or_create_global_step()
    self.build_network()
    self.build_predictions()
    if self.mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
      self.build_losses()
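The categorical quantization in `build_losses()` maps a clipped value x to floor((x - min_val) / (max_val - min_val) * num_classes), then folds the closed top edge into the last bucket. A small NumPy check of that arithmetic, with values chosen to hit the edge cases exercised by test_output_categorical below:

```python
import numpy as np

min_val, max_val, num_classes = 0.0, 1.0, 4

x = np.array([0.0, 0.2, 0.25, 0.5, 0.8, 1.0, 1.5])
clipped = np.clip(x, min_val, max_val)
buckets = np.floor((clipped - min_val) / (max_val - min_val) * num_classes)
# The top edge lands in bucket num_classes; fold it into the last bucket so
# the final bucket is a closed interval.
buckets = np.where(buckets >= num_classes, num_classes - 1, buckets)
print(buckets)  # [0. 0. 1. 2. 3. 3. 3.]
```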
research/astronet/astrowavenet/astrowavenet_model_test.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for astrowavenet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from astrowavenet import astrowavenet_model
from tf_util import configdict


class AstrowavenetTest(tf.test.TestCase):

  def assertShapeEquals(self, shape, tensor_or_array):
    """Asserts that a Tensor or Numpy array has the expected shape.

    Args:
      shape: Numpy array or anything that can be converted to one.
      tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
    """
    if isinstance(tensor_or_array, (np.ndarray, np.generic)):
      self.assertAllEqual(shape, tensor_or_array.shape)
    elif isinstance(tensor_or_array, (tf.Tensor, tf.Variable)):
      self.assertAllEqual(shape, tensor_or_array.shape.as_list())
    else:
      raise TypeError("tensor_or_array must be a Tensor or Numpy ndarray")

  def test_build_model(self):
    time_series_length = 9
    input_num_features = 8
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0.001,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    variables = {v.op.name: v for v in tf.trainable_variables()}

    # Verify variable shapes in two residual blocks.
    var = variables["preprocess/causal_conv/kernel"]
    self.assertShapeEquals((5, 8, 3), var)
    var = variables["preprocess/causal_conv/bias"]
    self.assertShapeEquals((3,), var)

    var = variables["block_0/dilation_1/filter/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_0/dilation_1/filter/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/filter/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_0/dilation_1/filter/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/gate/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_0/dilation_1/gate/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/gate/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_0/dilation_1/gate/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/residual/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 3), var)
    var = variables["block_0/dilation_1/residual/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/skip/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 6), var)
    var = variables["block_0/dilation_1/skip/conv1x1/bias"]
    self.assertShapeEquals((6,), var)

    var = variables["block_1/dilation_4/filter/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_1/dilation_4/filter/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/filter/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_1/dilation_4/filter/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/gate/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_1/dilation_4/gate/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/gate/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_1/dilation_4/gate/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/residual/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 3), var)
    var = variables["block_1/dilation_4/residual/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/skip/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 6), var)
    var = variables["block_1/dilation_4/skip/conv1x1/bias"]
    self.assertShapeEquals((6,), var)

    var = variables["postprocess/conv1x1/kernel"]
    self.assertShapeEquals((1, 6, 6), var)
    var = variables["postprocess/conv1x1/bias"]
    self.assertShapeEquals((6,), var)
    var = variables["dist_params/conv1x1/kernel"]
    self.assertShapeEquals((1, 6, 16), var)
    var = variables["dist_params/conv1x1/bias"]
    self.assertShapeEquals((16,), var)

    # Verify total number of trainable parameters.
    num_preprocess_params = (
        hparams.preprocess_kernel_width * input_num_features *
        hparams.preprocess_output_size + hparams.preprocess_output_size)
    num_gated_params = (
        hparams.dilation_kernel_width * hparams.preprocess_output_size *
        hparams.preprocess_output_size + hparams.preprocess_output_size +
        1 * context_num_features * hparams.preprocess_output_size +
        hparams.preprocess_output_size) * 2
    num_residual_params = (
        1 * hparams.preprocess_output_size * hparams.preprocess_output_size +
        hparams.preprocess_output_size)
    num_skip_params = (
        1 * hparams.preprocess_output_size * hparams.skip_output_dim +
        hparams.skip_output_dim)
    num_block_params = (
        num_gated_params + num_residual_params + num_skip_params) * len(
            hparams.dilation_rates) * hparams.num_residual_blocks
    num_postprocess_params = (
        1 * hparams.skip_output_dim * hparams.skip_output_dim +
        hparams.skip_output_dim)
    num_dist_params = (1 * hparams.skip_output_dim * 2 * input_num_features +
                       2 * input_num_features)
    total_params = (
        num_preprocess_params + num_block_params + num_postprocess_params +
        num_dist_params)

    total_retrieved_params = 0
    for v in tf.trainable_variables():
      total_retrieved_params += np.prod(v.shape)

    self.assertEqual(total_params, total_retrieved_params)

    # Verify model runs and outputs losses of correct shape.
    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      batch_size = 11
      feed_dict = {
          input_placeholder:
              np.random.random(
                  (batch_size, time_series_length, input_num_features)),
          context_placeholder:
              np.random.random(
                  (batch_size, time_series_length, context_num_features))
      }
      batch_losses, per_example_loss, total_loss = sess.run(
          [model.batch_losses, model.per_example_loss, model.total_loss],
          feed_dict=feed_dict)
      self.assertShapeEquals(
          (batch_size, time_series_length, input_num_features), batch_losses)
      self.assertShapeEquals((batch_size,), per_example_loss)
      self.assertShapeEquals((), total_loss)

  def test_build_model_categorical(self):
    time_series_length = 9
    input_num_features = 8
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "categorical",
            "num_classes": 256,
            "min_quantization_value": -1,
            "max_quantization_value": 1
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    variables = {v.op.name: v for v in tf.trainable_variables()}
    var = variables["dist_params/conv1x1/kernel"]
    self.assertShapeEquals(
        (1, hparams.skip_output_dim,
         hparams.output_distribution.num_classes * input_num_features), var)
    var = variables["dist_params/conv1x1/bias"]
    self.assertShapeEquals(
        (hparams.output_distribution.num_classes * input_num_features,), var)

    # Verify model runs and outputs losses of correct shape.
    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      batch_size = 11
      feed_dict = {
          input_placeholder:
              np.random.random(
                  (batch_size, time_series_length, input_num_features)),
          context_placeholder:
              np.random.random(
                  (batch_size, time_series_length, context_num_features))
      }
      batch_losses, per_example_loss, total_loss = sess.run(
          [model.batch_losses, model.per_example_loss, model.total_loss],
          feed_dict=feed_dict)
      self.assertShapeEquals(
          (batch_size, time_series_length, input_num_features), batch_losses)
      self.assertShapeEquals((batch_size,), per_example_loss)
      self.assertShapeEquals((), total_loss)

  def test_output_normal(self):
    time_series_length = 6
    input_num_features = 2
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    # Model predicts loc and scale.
    self.assertItemsEqual(["loc", "scale"], model.dist_params.keys())
    self.assertShapeEquals((None, time_series_length, input_num_features),
                           model.dist_params["loc"])
    self.assertShapeEquals((None, time_series_length, input_num_features),
                           model.dist_params["scale"])

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
              [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["loc"]: [
              [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
              [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
          ],
          model.dist_params["scale"]: [
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
          ],
      }
      batch_losses, per_example_loss, num_examples, total_loss = sess.run(
          [
              model.batch_losses, model.per_example_loss,
              model.num_nonzero_weight_examples, model.total_loss
          ],
          feed_dict=feed_dict)
      np.testing.assert_array_almost_equal(
          [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
           [[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]]], batch_losses)
      np.testing.assert_array_almost_equal([5.96392435, 5.96392435],
                                           per_example_loss)
      np.testing.assert_almost_equal(2, num_examples)
      np.testing.assert_almost_equal(5.96392435, total_loss)

  def test_output_categorical(self):
    time_series_length = 3
    input_num_features = 1
    context_num_features = 7
    num_classes = 4  # For quantized categorical output predictions.

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "categorical",
            "min_scale": 0,
            "num_classes": num_classes,
            "min_quantization_value": 0,
            "max_quantization_value": 1
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    self.assertItemsEqual(["logits"], model.dist_params.keys())
    self.assertShapeEquals(
        (None, time_series_length, input_num_features, num_classes),
        model.dist_params["logits"])

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[0], [0], [0]],  # min_quantization_value
              [[0.2], [0.2], [0.2]],  # Within bucket.
              [[0.25], [0.25], [0.25]],  # On bucket boundary.
              [[0.5], [0.5], [0.5]],  # On bucket boundary.
              [[0.8], [0.8], [0.8]],  # Within bucket.
              [[1], [1], [1]],  # max_quantization_value
              [[-0.1], [1.5], [200]],  # Outside range: will be clipped.
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["logits"]: [
              [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[0, 1, 0, 0]], [[1, 0, 0, 0]], [[0, 0, 1, 0]]],
              [[[0, 0, 1, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[0, 0, 0, 1]], [[1, 0, 0, 0]], [[1, 0, 0, 0]]],
              [[[0, 0, 0, 1]], [[0, 1, 0, 0]], [[0, 0, 1, 0]]],
              [[[1, 0, 0, 0]], [[0, 0, 1, 0]], [[0, 1, 0, 0]]],
          ],
      }
      (target, batch_losses, per_example_loss, num_examples,
       total_loss) = sess.run([
           model.autoregressive_target, model.batch_losses,
           model.per_example_loss, model.num_nonzero_weight_examples,
           model.total_loss
       ],
                              feed_dict=feed_dict)
      np.testing.assert_array_almost_equal([
          [[0], [0], [0]],
          [[0], [0], [0]],
          [[1], [1], [1]],
          [[2], [2], [2]],
          [[3], [3], [3]],
          [[3], [3], [3]],
          [[0], [3], [3]],
      ], target)
      np.testing.assert_array_almost_equal([
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
      ], batch_losses)
      np.testing.assert_array_almost_equal([
          1.41033504, 1.41033504, 1.41033504, 1.41033504, 1.41033504,
          1.41033504, 1.41033504
      ], per_example_loss)
      np.testing.assert_almost_equal(7, num_examples)
      np.testing.assert_almost_equal(1.41033504, total_loss)

  def test_output_weighted(self):
    time_series_length = 6
    input_num_features = 2
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    weights_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "weights": weights_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
              [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
              [[3, 7], [3, 7], [3, 7], [3, 7], [3, 7], [3, 7]],
          ],
          weights_placeholder: [
              [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
              [[1, 0], [1, 1], [1, 1], [0, 1], [0, 1], [0, 0]],
              [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["loc"]: [
              [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
              [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
              [[3, 6], [3, 6], [3, 6], [3, 6], [3, 6], [3, 6]],
          ],
          model.dist_params["scale"]: [
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
          ],
      }
      batch_losses, per_example_loss, num_examples, total_loss = sess.run(
          [
              model.batch_losses, model.per_example_loss,
              model.num_nonzero_weight_examples, model.total_loss
          ],
          feed_dict=feed_dict)
      np.testing.assert_array_almost_equal(
          [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
           [[-1.38364656, 0], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0, 1.41893853], [0, 1.73708571],
            [0, 0]],
           [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]], batch_losses)
      np.testing.assert_array_almost_equal([5.96392435, 2.19185166, 0],
                                           per_example_loss)
      np.testing.assert_almost_equal(2, num_examples)
      np.testing.assert_almost_equal(4.07788801, total_loss)

  def test_causality(self):
    time_series_length = 7
    input_num_features = 1
    context_num_features = 1

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 1,
        "skip_output_dim": 1,
        "preprocess_output_size": 1,
        "preprocess_kernel_width": 1,
        "num_residual_blocks": 1,
        "dilation_rates": [1],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0.001,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[0], [0], [0], [0], [0], [0], [0]],
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
          ],
          context_placeholder: [
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
          ],
      }
      network_output = sess.run(model.network_output, feed_dict=feed_dict)
      np.testing.assert_array_equal(
          [
              [[0], [0], [0], [0], [0], [0], [0]],
              # Input elements are used to predict the next timestamp.
              [[0], [1], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [1], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              # Context elements are used to predict the current timestamp.
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
          ],
          np.greater(np.abs(network_output), 0))


if __name__ == "__main__":
  tf.test.main()
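test_output_weighted above pins down the loss reduction: per-example losses are weight-normalized sums, and the batch loss averages only over examples with nonzero total weight. A plain-NumPy restatement of that reduction (the negative log probs and weights passed in are hypothetical inputs):

```python
import numpy as np

def reduce_losses(neg_log_prob, weights):
  """Mirrors the reductions in AstroWaveNet.build_losses() in NumPy."""
  batch_losses = neg_log_prob * weights  # Shape [batch, time, features].
  per_example_weight = weights.sum(axis=(1, 2))
  per_example_sum = batch_losses.sum(axis=(1, 2))
  per_example_loss = np.where(
      per_example_weight > 0,
      per_example_sum / np.maximum(per_example_weight, 1e-12), 0.0)
  num_examples = (per_example_weight > 0).sum()
  return per_example_loss, per_example_loss.sum() / num_examples
```

With the values from test_output_weighted, the per-example losses come out to [5.96392435, 2.19185166, 0] and the total loss to their mean over the two nonzero-weight examples, 4.07788801.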
research/astronet/astrowavenet/configurations.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for model building, training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def base():
  """Returns the base config for model building, training and evaluation."""
  return {
      # Hyperparameters for building and training the model.
      "hparams": {
          "batch_size": 64,

          "dilation_kernel_width": 2,
          "skip_output_dim": 10,
          "preprocess_output_size": 3,
          "preprocess_kernel_width": 10,
          "num_residual_blocks": 4,
          "dilation_rates": [1, 2, 4, 8, 16],
          "output_distribution": {
              "type": "normal",
              "min_scale": 0.001
          },

          # Learning rate parameters.
          "learning_rate": 1e-6,
          "learning_rate_decay_steps": 0,
          "learning_rate_decay_factor": 0,
          "learning_rate_decay_staircase": True,

          # Optimizer for training the model.
          "optimizer": "adam",

          # If not None, gradient norms will be clipped to this value.
          "clip_gradient_norm": 1,
      }
  }


def categorical():
  """Returns a config for models with a categorical output distribution.

  Input values will be clipped to [min_quantization_value,
  max_quantization_value], then linearly split into num_classes buckets.
  """
  config = base()
  config["hparams"]["output_distribution"] = {
      "type": "categorical",
      "num_classes": 256,
      "min_quantization_value": -1,
      "max_quantization_value": 1
  }
  return config


def get_config(config_name):
  """Returns the config corresponding to the provided name."""
  if config_name in ["base", "normal"]:
    return base()
  elif config_name == "categorical":
    return categorical()
  else:
    raise ValueError("Unrecognized config name: {}".format(config_name))
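A hedged usage sketch of these configs; the override values mirror the `--config_overrides` example in the README above:

```python
from tf_util import configdict
from astrowavenet import configurations

config = configdict.ConfigDict(configurations.get_config("categorical"))

# Overrides of the kind passed via --config_overrides on the command line.
config.hparams.batch_size = 16
config.hparams.num_residual_blocks = 2

print(config.hparams.output_distribution.type)  # "categorical"
```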
research/astronet/astrowavenet/data/BUILD
deleted
100644 → 0
View file @
17c2f0cc
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "base",
    srcs = ["base.py"],
    deps = [
        "//astronet/ops:dataset_ops",
        "//tf_util:configdict",
    ],
)

py_test(
    name = "base_test",
    srcs = ["base_test.py"],
    data = ["test_data/test-dataset.tfrecord"],
    srcs_version = "PY2AND3",
    deps = [":base"],
)

py_library(
    name = "kepler_light_curves",
    srcs = ["kepler_light_curves.py"],
    deps = [
        ":base",
        "//tf_util:configdict",
    ],
)

py_library(
    name = "synthetic_transits",
    srcs = ["synthetic_transits.py"],
    deps = [
        ":base",
        ":synthetic_transit_maker",
        "//tf_util:configdict",
    ],
)

py_library(
    name = "synthetic_transit_maker",
    srcs = ["synthetic_transit_maker.py"],
)

py_test(
    name = "synthetic_transit_maker_test",
    srcs = ["synthetic_transit_maker_test.py"],
    srcs_version = "PY2AND3",
    deps = [":synthetic_transit_maker"],
)