ModelZoo / ResNet50_tensorflow / Commits / 65da497f

Commit 65da497f, authored Dec 13, 2018 by Shining Sun

Merge branch 'master' of https://github.com/tensorflow/models into cifar_keras

Parents: 93e0022d, 7d032ea3
Changes: 186 files changed in this commit. Showing 20 changed files with 0 additions and 3535 deletions (+0 −3535):

* research/astronet/astronet/util/BUILD (+0 −65)
* research/astronet/astronet/util/__init__.py (+0 −14)
* research/astronet/astronet/util/config_util.py (+0 −126)
* research/astronet/astronet/util/configdict.py (+0 −64)
* research/astronet/astronet/util/configdict_test.py (+0 −180)
* research/astronet/astronet/util/estimator_runner.py (+0 −147)
* research/astronet/astronet/util/estimator_util.py (+0 −252)
* research/astronet/astronet/util/example_util.py (+0 −143)
* research/astronet/astronet/util/example_util_test.py (+0 −180)
* research/astronet/astrowavenet/BUILD (+0 −49)
* research/astronet/astrowavenet/README.md (+0 −44)
* research/astronet/astrowavenet/__init__.py (+0 −14)
* research/astronet/astrowavenet/astrowavenet_model.py (+0 −321)
* research/astronet/astrowavenet/astrowavenet_model_test.py (+0 −631)
* research/astronet/astrowavenet/configurations.py (+0 −76)
* research/astronet/astrowavenet/data/BUILD (+0 −59)
* research/astronet/astrowavenet/data/__init__.py (+0 −14)
* research/astronet/astrowavenet/data/base.py (+0 −240)
* research/astronet/astrowavenet/data/base_test.py (+0 −778)
* research/astronet/astrowavenet/data/synthetic_transit_maker.py (+0 −138)
research/astronet/astronet/util/BUILD (deleted, 100644 → 0)

```python
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "configdict",
    srcs = ["configdict.py"],
    srcs_version = "PY2AND3",
    deps = [],
)

py_test(
    name = "configdict_test",
    size = "small",
    srcs = ["configdict_test.py"],
    srcs_version = "PY2AND3",
    deps = [":configdict"],
)

py_library(
    name = "config_util",
    srcs = ["config_util.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "config_util_test",
    size = "small",
    srcs = ["config_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [":config_util"],
)

py_library(
    name = "estimator_runner",
    srcs = ["estimator_runner.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//astronet/ops:dataset_ops",
        "//astronet/ops:metrics",
        "//astronet/ops:training",
    ],
)

py_library(
    name = "example_util",
    srcs = ["example_util.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
)

py_test(
    name = "example_util_test",
    size = "small",
    srcs = ["example_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [":example_util"],
)
```
research/astronet/astronet/util/__init__.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
research/astronet/astronet/util/config_util.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for configurations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os.path

import tensorflow as tf


def parse_json(json_string_or_file):
  """Parses values from a JSON string or JSON file.

  This function is useful for command line flags containing configuration
  overrides. Using this function, the flag can be passed either as a JSON
  string (e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration
  file.

  Args:
    json_string_or_file: A JSON serialized string OR the path to a JSON file.

  Returns:
    A dictionary; the parsed JSON.

  Raises:
    ValueError: If the JSON could not be parsed.
  """
  # First, attempt to parse the string as a JSON dict.
  try:
    json_dict = json.loads(json_string_or_file)
  except ValueError as literal_json_parsing_error:
    try:
      # Otherwise, try to use it as a path to a JSON file.
      with tf.gfile.Open(json_string_or_file) as f:
        json_dict = json.load(f)
    except ValueError as json_file_parsing_error:
      raise ValueError("Unable to parse the content of the json file {}. "
                       "Parsing error: {}.".format(
                           json_string_or_file,
                           json_file_parsing_error.message))
    except tf.gfile.FileError:
      message = ("Unable to parse the input parameter neither as literal "
                 "JSON nor as the name of a file that exists.\n"
                 "JSON parsing error: {}\n\nInput parameter:\n{}.".format(
                     literal_json_parsing_error.message, json_string_or_file))
      raise ValueError(message)

  return json_dict


def to_json(config):
  """Converts a JSON-serializable configuration object to a JSON string."""
  if hasattr(config, "to_json") and callable(config.to_json):
    return config.to_json(indent=2)
  else:
    return json.dumps(config, indent=2)


def log_and_save_config(config, output_dir):
  """Logs and writes a JSON-serializable configuration object.

  Args:
    config: A JSON-serializable object.
    output_dir: Destination directory.
  """
  config_json = to_json(config)
  tf.logging.info("config: %s", config_json)

  tf.gfile.MakeDirs(output_dir)
  with tf.gfile.Open(os.path.join(output_dir, "config.json"), "w") as f:
    f.write(config_json)


def unflatten(flat_config):
  """Transforms a flat configuration dictionary into a nested dictionary.

  Example:
    {
      "a": 1,
      "b.c": 2,
      "b.d.e": 3,
      "b.d.f": 4,
    }
  would be transformed to:
    {
      "a": 1,
      "b": {
        "c": 2,
        "d": {
          "e": 3,
          "f": 4,
        }
      }
    }

  Args:
    flat_config: A dictionary with strings as keys where nested configuration
      parameters are represented with period-separated names.

  Returns:
    A dictionary nested according to the keys of the input dictionary.
  """
  config = {}
  for path, value in flat_config.items():
    path = path.split(".")
    final_key = path.pop()
    nested_config = config
    for key in path:
      nested_config = nested_config.setdefault(key, {})
    nested_config[final_key] = value
  return config
```
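For orientation, a minimal usage sketch of `unflatten` (not part of the deleted file; the flat override dict here is hypothetical):

```python
from astronet.util import config_util

# Dot-separated override keys, e.g. as passed via a command line flag.
flat = {"learning_rate": 1.0, "inputs.batch_size": 16, "inputs.shuffle": True}

# unflatten() turns period-separated keys into nested dicts, so the overrides
# can be merged into a nested base configuration.
nested = config_util.unflatten(flat)
assert nested == {
    "learning_rate": 1.0,
    "inputs": {"batch_size": 16, "shuffle": True},
}
```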
research/astronet/astronet/util/configdict.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration container for TensorFlow models.

A ConfigDict is simply a dict whose values can be accessed via both dot syntax
(config.key) and dict syntax (config['key']).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _maybe_convert_dict(value):
  if isinstance(value, dict):
    return ConfigDict(value)
  return value


class ConfigDict(dict):
  """Configuration container class."""

  def __init__(self, initial_dictionary=None):
    """Creates an instance of ConfigDict.

    Args:
      initial_dictionary: Optional dictionary or ConfigDict containing initial
        parameters.
    """
    if initial_dictionary:
      for field, value in initial_dictionary.items():
        initial_dictionary[field] = _maybe_convert_dict(value)
    super(ConfigDict, self).__init__(initial_dictionary)

  def __setattr__(self, attribute, value):
    self[attribute] = _maybe_convert_dict(value)

  def __getattr__(self, attribute):
    try:
      return self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __delattr__(self, attribute):
    try:
      del self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __setitem__(self, key, value):
    super(ConfigDict, self).__setitem__(key, _maybe_convert_dict(value))
```
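A short sketch of the dual access syntax this class provides (illustrative only; the keys are hypothetical):

```python
from astronet.util import configdict

config = configdict.ConfigDict({
    "learning_rate": 0.001,
    "inputs": {"batch_size": 16},  # Nested dicts become ConfigDicts.
})

# Dot syntax and dict syntax are interchangeable, at every nesting level.
assert config.learning_rate == config["learning_rate"]
assert config.inputs.batch_size == config["inputs"]["batch_size"] == 16

config.inputs.batch_size = 32  # Updates the same underlying entry.
assert config["inputs"]["batch_size"] == 32
```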
research/astronet/astronet/util/configdict_test.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for config_util.configdict."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest

from astronet.util import configdict


class ConfigDictTest(absltest.TestCase):

  def setUp(self):
    super(ConfigDictTest, self).setUp()
    self._config = configdict.ConfigDict({
        "int": 1,
        "float": 2.0,
        "bool": True,
        "str": "hello",
        "nested": {
            "int": 3,
        },
        "double_nested": {
            "a": {
                "int": 3,
            },
            "b": {
                "float": 4.0,
            }
        }
    })

  def testAccess(self):
    # Simple types.
    self.assertEqual(1, self._config.int)
    self.assertEqual(1, self._config["int"])
    self.assertEqual(2.0, self._config.float)
    self.assertEqual(2.0, self._config["float"])
    self.assertTrue(self._config.bool)
    self.assertTrue(self._config["bool"])
    self.assertEqual("hello", self._config.str)
    self.assertEqual("hello", self._config["str"])

    # Single nested config.
    self.assertEqual(3, self._config.nested.int)
    self.assertEqual(3, self._config["nested"].int)
    self.assertEqual(3, self._config.nested["int"])
    self.assertEqual(3, self._config["nested"]["int"])

    # Double nested config.
    self.assertEqual(3, self._config["double_nested"].a.int)
    self.assertEqual(3, self._config["double_nested"]["a"].int)
    self.assertEqual(3, self._config["double_nested"].a["int"])
    self.assertEqual(3, self._config["double_nested"]["a"]["int"])
    self.assertEqual(4.0, self._config.double_nested.b.float)
    self.assertEqual(4.0, self._config.double_nested["b"].float)
    self.assertEqual(4.0, self._config.double_nested.b["float"])
    self.assertEqual(4.0, self._config.double_nested["b"]["float"])

    # Nonexistent parameters.
    with self.assertRaises(AttributeError):
      _ = self._config.nonexistent
    with self.assertRaises(KeyError):
      _ = self._config["nonexistent"]

  def testSetAttribute(self):
    # Overwrite existing simple type.
    self._config.int = 40
    self.assertEqual(40, self._config.int)

    # Overwrite existing nested simple type.
    self._config.nested.int = 40
    self.assertEqual(40, self._config.nested.int)

    # Overwrite existing nested config.
    self._config.double_nested.a = {"float": 50.0}
    self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
    self.assertEqual(50.0, self._config.double_nested.a.float)
    self.assertNotIn("int", self._config.double_nested.a)

    # Set new simple type.
    self._config.int_2 = 10
    self.assertEqual(10, self._config.int_2)

    # Set new nested simple type.
    self._config.nested.int_2 = 20
    self.assertEqual(20, self._config.nested.int_2)

    # Set new nested config.
    self._config.double_nested.c = {"int": 30}
    self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
    self.assertEqual(30, self._config.double_nested.c.int)

  def testSetItem(self):
    # Overwrite existing simple type.
    self._config["int"] = 40
    self.assertEqual(40, self._config.int)

    # Overwrite existing nested simple type.
    self._config["nested"].int = 40
    self.assertEqual(40, self._config.nested.int)
    self._config.nested["int"] = 50
    self.assertEqual(50, self._config.nested.int)

    # Overwrite existing nested config.
    self._config.double_nested["a"] = {"float": 50.0}
    self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
    self.assertEqual(50.0, self._config.double_nested.a.float)
    self.assertNotIn("int", self._config.double_nested.a)

    # Set new simple type.
    self._config["int_2"] = 10
    self.assertEqual(10, self._config.int_2)

    # Set new nested simple type.
    self._config.nested["int_2"] = 20
    self.assertEqual(20, self._config.nested.int_2)
    self._config.nested["int_3"] = 30
    self.assertEqual(30, self._config.nested.int_3)

    # Set new nested config.
    self._config.double_nested["c"] = {"int": 30}
    self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
    self.assertEqual(30, self._config.double_nested.c.int)

  def testDelete(self):
    # Simple types.
    self.assertEqual(1, self._config.int)
    del self._config.int
    with self.assertRaises(AttributeError):
      _ = self._config.int
    with self.assertRaises(KeyError):
      _ = self._config["int"]
    self.assertEqual(2.0, self._config["float"])
    del self._config["float"]
    with self.assertRaises(AttributeError):
      _ = self._config.float
    with self.assertRaises(KeyError):
      _ = self._config["float"]

    # Nested config.
    self.assertEqual(3, self._config.nested.int)
    del self._config.nested
    with self.assertRaises(AttributeError):
      _ = self._config.nested
    with self.assertRaises(KeyError):
      _ = self._config["nested"]


if __name__ == "__main__":
  absltest.main()
```
research/astronet/astronet/util/estimator_runner.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for training and evaluation using a TensorFlow Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def evaluate(estimator, eval_args):
  """Runs evaluation on the latest model checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).

  Returns:
    global_step: The global step of the checkpoint evaluated.
    values: A dict of metric values from the evaluation. May be empty, e.g. if
      the training job has not yet saved a checkpoint or the checkpoint is
      deleted by the time the TPU worker initializes.
  """
  # Default return values if evaluation fails.
  global_step = None
  values = {}

  latest_checkpoint = estimator.latest_checkpoint()
  if not latest_checkpoint:
    # This is expected if the training job has not yet saved a checkpoint.
    return global_step, values

  tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
  try:
    for eval_name, (input_fn, eval_steps) in eval_args.items():
      values[eval_name] = estimator.evaluate(
          input_fn, steps=eval_steps, name=eval_name)
      if global_step is None:
        global_step = values[eval_name].get("global_step")
  except (tf.errors.NotFoundError, ValueError):
    # Expected under some conditions, e.g. checkpoint is already deleted by
    # the trainer process. Increasing RunConfig.keep_checkpoint_max may
    # prevent this in some cases.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation.",
                    latest_checkpoint)

  return global_step, values


def continuous_eval(estimator,
                    eval_args,
                    train_steps=None,
                    timeout_secs=None,
                    timeout_fn=None):
  """Runs evaluation whenever there's a new checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).
    train_steps: The number of steps the model will train for. This function
      will terminate once the model has finished training.
    timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
      indefinitely.
    timeout_fn: Optional function to call after timeout. The iterator will
      exit if and only if the function returns True.

  Yields:
    A dict of metric values from each evaluation. May be empty, e.g. if the
    training job has not yet saved a checkpoint or the checkpoint is deleted
    by the time the TPU worker initializes.
  """
  for _ in tf.contrib.training.checkpoints_iterator(
      estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break


def continuous_train_and_eval(estimator,
                              train_input_fn,
                              eval_args,
                              local_eval_frequency=None,
                              train_hooks=None,
                              train_steps=None):
  """Alternates training and evaluation.

  Args:
    estimator: Instance of tf.Estimator.
    train_input_fn: Input function returning a tuple (features, labels).
    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
      eval_name is the name of the evaluation set (e.g. "train" or "val"),
      input_fn is an input function returning a tuple (features, labels), and
      eval_steps is the number of steps for which to evaluate the model (if
      None, evaluates until input_fn raises an end-of-input exception).
    local_eval_frequency: The number of training steps between evaluations.
      If None, trains until train_input_fn raises an end-of-input exception.
    train_hooks: List of SessionRunHook subclass instances. Used for callbacks
      inside the training call.
    train_steps: The total number of steps to train the model for.

  Yields:
    A dict of metric values from each evaluation. May be empty, e.g. if the
    training job has not yet saved a checkpoint or the checkpoint is deleted
    by the time the TPU worker initializes.
  """
  while True:
    # We run evaluation before training in this loop to prevent evaluation
    # from being skipped if the process is interrupted.
    global_step, values = evaluate(estimator, eval_args)
    yield global_step, values

    global_step = global_step or 0  # Ensure global_step is not None.
    if train_steps and global_step >= train_steps:
      break

    # Decide how many steps before the next evaluation.
    steps = local_eval_frequency
    if train_steps:
      remaining_steps = train_steps - global_step
      steps = min(steps, remaining_steps) if steps else remaining_steps

    tf.logging.info("Starting training at global step %d", global_step)
    estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
```
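A minimal sketch of how this runner might be driven (the `estimator` and the input functions are assumed to be created elsewhere, e.g. via `estimator_util`; this loop is not part of the deleted file):

```python
# eval_steps=None means: evaluate "val" until the input function is exhausted.
eval_args = {"val": (val_input_fn, None)}

for global_step, metrics in continuous_train_and_eval(
    estimator,
    train_input_fn,
    eval_args,
    local_eval_frequency=1000,  # Evaluate every 1000 training steps.
    train_steps=10000):
  print("step", global_step, "metrics", metrics)
```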
research/astronet/astronet/util/estimator_util.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for creating a TensorFlow Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import tensorflow as tf

from astronet.ops import dataset_ops
from astronet.ops import metrics
from astronet.ops import training


class _InputFn(object):
  """Class that acts as a callable input function for Estimator train / eval."""

  def __init__(self,
               file_pattern,
               input_config,
               mode,
               shuffle_values_buffer=0,
               repeat=1):
    """Initializes the input function.

    Args:
      file_pattern: File pattern matching input TFRecord files, e.g.
        "/tmp/train-?????-of-00100". May also be a comma-separated list of
        file patterns.
      input_config: ConfigDict containing feature and label specifications.
      mode: A tf.estimator.ModeKeys.
      shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
        size.
      repeat: The number of times to repeat the dataset. If None or -1 the
        elements will be repeated indefinitely.
    """
    self._file_pattern = file_pattern
    self._input_config = input_config
    self._mode = mode
    self._shuffle_values_buffer = shuffle_values_buffer
    self._repeat = repeat

  def __call__(self, config, params):
    """Builds the input pipeline."""
    # Infer whether this input_fn was called by Estimator or TPUEstimator
    # using the config type.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    mode = self._mode
    include_labels = (
        mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
    reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
    shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)

    dataset = dataset_ops.build_dataset(
        file_pattern=self._file_pattern,
        input_config=self._input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=self._shuffle_values_buffer,
        repeat=self._repeat,
        use_tpu=use_tpu)

    return dataset


def create_input_fn(file_pattern,
                    input_config,
                    mode,
                    shuffle_values_buffer=0,
                    repeat=1):
  """Creates an input_fn that reads a dataset from sharded TFRecord files.

  Args:
    file_pattern: File pattern matching input TFRecord files, e.g.
      "/tmp/train-?????-of-00100". May also be a comma-separated list of file
      patterns.
    input_config: ConfigDict containing feature and label specifications.
    mode: A tf.estimator.ModeKeys.
    shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
      size.
    repeat: The number of times to repeat the dataset. If None or -1 the
      elements will be repeated indefinitely.

  Returns:
    A callable that builds the input pipeline and returns a tf.data.Dataset
    object.
  """
  return _InputFn(file_pattern, input_config, mode, shuffle_values_buffer,
                  repeat)


class _ModelFn(object):
  """Class that acts as a callable model function for Estimator train / eval."""

  def __init__(self, model_class, hparams, use_tpu=False):
    """Initializes the model function.

    Args:
      model_class: Model class.
      hparams: ConfigDict containing hyperparameters for building and training
        the model.
      use_tpu: If True, a TPUEstimator will be returned. Otherwise an
        Estimator will be returned.
    """
    self._model_class = model_class
    self._base_hparams = hparams
    self._use_tpu = use_tpu

  def __call__(self, features, labels, mode, params):
    """Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
    hparams = copy.deepcopy(self._base_hparams)
    if "batch_size" in params:
      hparams.batch_size = params["batch_size"]

    # Allow labels to be passed in the features dictionary.
    if "labels" in features:
      if labels is not None and labels is not features["labels"]:
        raise ValueError(
            "Conflicting labels: features['labels'] = {}, labels = {}".format(
                features["labels"], labels))
      labels = features.pop("labels")

    model = self._model_class(features, labels, hparams, mode)
    model.build()

    # Possibly create train_op.
    use_tpu = self._use_tpu
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      learning_rate = training.create_learning_rate(hparams, model.global_step)
      optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
      train_op = training.create_train_op(model, optimizer)

    # Possibly create evaluation metrics.
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
      eval_metrics = (
          metrics.create_metric_fn(model)
          if use_tpu else metrics.create_metrics(model))

    if use_tpu:
      estimator = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metrics=eval_metrics)
    else:
      estimator = tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metrics)

    return estimator


def create_model_fn(model_class, hparams, use_tpu=False):
  """Wraps model_class as an Estimator or TPUEstimator model_fn.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
      Estimator model_fn is returned.

  Returns:
    model_fn: A callable that constructs the model and returns a
      TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
  """
  return _ModelFn(model_class, hparams, use_tpu)


def create_estimator(model_class,
                     hparams,
                     run_config=None,
                     model_dir=None,
                     eval_batch_size=None):
  """Wraps model_class as an Estimator or TPUEstimator.

  If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
  If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
    model_dir: Optional directory for saving the model. If not passed
      explicitly, it must be specified in run_config.
    eval_batch_size: Optional batch size for evaluation on TPU. Only
      applicable if run_config is a tf.contrib.tpu.RunConfig. Defaults to
      hparams.batch_size.

  Returns:
    An Estimator object if run_config is None or a tf.estimator.RunConfig, or
    a TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.

  Raises:
    ValueError:
      If model_dir is not passed explicitly or in run_config.model_dir, or if
      eval_batch_size is specified and run_config is not a
      tf.contrib.tpu.RunConfig.
  """
  if run_config is None:
    run_config = tf.estimator.RunConfig()
  else:
    run_config = copy.deepcopy(run_config)

  if not model_dir and not run_config.model_dir:
    raise ValueError(
        "model_dir must be passed explicitly or specified in run_config")

  use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
  model_fn = create_model_fn(model_class, hparams, use_tpu)

  if use_tpu:
    eval_batch_size = eval_batch_size or hparams.batch_size
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=eval_batch_size)
  else:
    if eval_batch_size is not None:
      raise ValueError("eval_batch_size can only be specified for TPU.")
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        params={"batch_size": hparams.batch_size})

  return estimator
```
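A sketch of how these helpers fit together on CPU/GPU (hedged: `AstroModel`, `hparams`, and `input_config` are assumed to come from elsewhere in the package and are not defined here):

```python
import tensorflow as tf

# Hypothetical setup; AstroModel, hparams and input_config are assumptions.
run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
estimator = create_estimator(
    AstroModel, hparams, run_config=run_config, model_dir="/tmp/astronet")

train_input_fn = create_input_fn(
    file_pattern="/tmp/tfrecords/train-*",
    input_config=input_config,
    mode=tf.estimator.ModeKeys.TRAIN,
    shuffle_values_buffer=1000,
    repeat=None)  # Repeat the dataset indefinitely.

estimator.train(train_input_fn, max_steps=10000)
```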
research/astronet/astronet/util/example_util.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for getting and setting values in tf.Example protocol buffers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def get_feature(ex, name, kind=None, strict=True):
  """Gets a feature value from a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to look up.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred
      if not specified.
    strict: Whether to raise a KeyError if there is no such feature.

  Returns:
    A numpy array containing the values of the specified feature.

  Raises:
    KeyError: If there is no feature with the specified name.
    TypeError: If the feature has a different type to that specified.
  """
  if name not in ex.features.feature:
    if strict:
      raise KeyError(name)
    return np.array([])

  inferred_kind = ex.features.feature[name].WhichOneof("kind")
  if not inferred_kind:
    return np.array([])  # Feature exists, but it's empty.

  if kind and kind != inferred_kind:
    raise TypeError("Requested {}, but Feature has {}".format(
        kind, inferred_kind))

  return np.array(getattr(ex.features.feature[name], inferred_kind).value)


def get_bytes_feature(ex, name, strict=True):
  """Gets the value of a bytes feature from a tf.train.Example."""
  return get_feature(ex, name, "bytes_list", strict)


def get_float_feature(ex, name, strict=True):
  """Gets the value of a float feature from a tf.train.Example."""
  return get_feature(ex, name, "float_list", strict)


def get_int64_feature(ex, name, strict=True):
  """Gets the value of an int64 feature from a tf.train.Example."""
  return get_feature(ex, name, "int64_list", strict)


def _infer_kind(value):
  """Infers the tf.train.Feature kind from a value."""
  if np.issubdtype(type(value[0]), np.integer):
    return "int64_list"
  try:
    float(value[0])
    return "float_list"
  except ValueError:
    return "bytes_list"


def set_feature(ex,
                name,
                value,
                kind=None,
                allow_overwrite=False,
                bytes_encoding="latin-1"):
  """Sets a feature value in a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to set.
    value: Feature value to set. Must be a sequence.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred
      if not specified.
    allow_overwrite: Whether to overwrite the existing value of the feature.
    bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.

  Raises:
    ValueError: If `allow_overwrite` is False and the feature already exists,
      or if `kind` is unrecognized.
  """
  if name in ex.features.feature:
    if allow_overwrite:
      del ex.features.feature[name]
    else:
      raise ValueError(
          "Attempting to overwrite feature with name: {}. "
          "Set allow_overwrite=True if this is desired.".format(name))

  if not kind:
    kind = _infer_kind(value)

  if kind == "bytes_list":
    value = [str(v).encode(bytes_encoding) for v in value]
  elif kind == "float_list":
    value = [float(v) for v in value]
  elif kind == "int64_list":
    value = [int(v) for v in value]
  else:
    raise ValueError("Unrecognized kind: {}".format(kind))

  getattr(ex.features.feature[name], kind).value.extend(value)


def set_float_feature(ex, name, value, allow_overwrite=False):
  """Sets the value of a float feature in a tf.train.Example."""
  set_feature(ex, name, value, "float_list", allow_overwrite)


def set_bytes_feature(ex,
                      name,
                      value,
                      allow_overwrite=False,
                      bytes_encoding="latin-1"):
  """Sets the value of a bytes feature in a tf.train.Example."""
  set_feature(ex, name, value, "bytes_list", allow_overwrite, bytes_encoding)


def set_int64_feature(ex, name, value, allow_overwrite=False):
  """Sets the value of an int64 feature in a tf.train.Example."""
  set_feature(ex, name, value, "int64_list", allow_overwrite)
```
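A short round-trip sketch of these helpers (illustrative only; the feature names and values are hypothetical):

```python
import tensorflow as tf
from astronet.util import example_util

ex = tf.train.Example()
example_util.set_float_feature(ex, "flux", [1.0, 0.99, 1.01])
example_util.set_bytes_feature(ex, "kepler_id", ["11442793"])

# Kinds are inferred on read unless specified explicitly.
flux = example_util.get_feature(ex, "flux")  # ~[1.0, 0.99, 1.01] (float32).
kepid = example_util.get_bytes_feature(ex, "kepler_id")  # [b'11442793']

# Overwriting an existing feature requires explicit opt-in.
example_util.set_float_feature(ex, "flux", [2.0], allow_overwrite=True)
```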
research/astronet/astronet/util/example_util_test.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for example_util.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from astronet.util import example_util


class ExampleUtilTest(tf.test.TestCase):

  def test_get_feature(self):
    # Create Example.
    bytes_list = tf.train.BytesList(
        value=[v.encode("latin-1") for v in ["a", "b", "c"]])
    float_list = tf.train.FloatList(value=[1.0, 2.0, 3.0])
    int64_list = tf.train.Int64List(value=[11, 22, 33])
    ex = tf.train.Example(
        features=tf.train.Features(
            feature={
                "a_bytes": tf.train.Feature(bytes_list=bytes_list),
                "b_float": tf.train.Feature(float_list=float_list),
                "c_int64": tf.train.Feature(int64_list=int64_list),
                "d_empty": tf.train.Feature(),
            }))

    # Get bytes feature.
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "a_bytes").astype(str), ["a", "b", "c"])
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "a_bytes", "bytes_list").astype(str),
        ["a", "b", "c"])
    np.testing.assert_array_equal(
        example_util.get_bytes_feature(ex, "a_bytes").astype(str),
        ["a", "b", "c"])
    with self.assertRaises(TypeError):
      example_util.get_feature(ex, "a_bytes", "float_list")
    with self.assertRaises(TypeError):
      example_util.get_float_feature(ex, "a_bytes")
    with self.assertRaises(TypeError):
      example_util.get_int64_feature(ex, "a_bytes")

    # Get float feature.
    np.testing.assert_array_almost_equal(
        example_util.get_feature(ex, "b_float"), [1.0, 2.0, 3.0])
    np.testing.assert_array_almost_equal(
        example_util.get_feature(ex, "b_float", "float_list"),
        [1.0, 2.0, 3.0])
    np.testing.assert_array_almost_equal(
        example_util.get_float_feature(ex, "b_float"), [1.0, 2.0, 3.0])
    with self.assertRaises(TypeError):
      example_util.get_feature(ex, "b_float", "int64_list")
    with self.assertRaises(TypeError):
      example_util.get_bytes_feature(ex, "b_float")
    with self.assertRaises(TypeError):
      example_util.get_int64_feature(ex, "b_float")

    # Get int64 feature.
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "c_int64"), [11, 22, 33])
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "c_int64", "int64_list"), [11, 22, 33])
    np.testing.assert_array_equal(
        example_util.get_int64_feature(ex, "c_int64"), [11, 22, 33])
    with self.assertRaises(TypeError):
      example_util.get_feature(ex, "c_int64", "bytes_list")
    with self.assertRaises(TypeError):
      example_util.get_bytes_feature(ex, "c_int64")
    with self.assertRaises(TypeError):
      example_util.get_float_feature(ex, "c_int64")

    # Get empty feature.
    np.testing.assert_array_equal(example_util.get_feature(ex, "d_empty"), [])
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "d_empty", "float_list"), [])
    np.testing.assert_array_equal(
        example_util.get_bytes_feature(ex, "d_empty"), [])
    np.testing.assert_array_equal(
        example_util.get_float_feature(ex, "d_empty"), [])
    np.testing.assert_array_equal(
        example_util.get_int64_feature(ex, "d_empty"), [])

    # Get nonexistent feature.
    with self.assertRaises(KeyError):
      example_util.get_feature(ex, "nonexistent")
    with self.assertRaises(KeyError):
      example_util.get_feature(ex, "nonexistent", "bytes_list")
    with self.assertRaises(KeyError):
      example_util.get_bytes_feature(ex, "nonexistent")
    with self.assertRaises(KeyError):
      example_util.get_float_feature(ex, "nonexistent")
    with self.assertRaises(KeyError):
      example_util.get_int64_feature(ex, "nonexistent")
    np.testing.assert_array_equal(
        example_util.get_feature(ex, "nonexistent", strict=False), [])
    np.testing.assert_array_equal(
        example_util.get_bytes_feature(ex, "nonexistent", strict=False), [])
    np.testing.assert_array_equal(
        example_util.get_float_feature(ex, "nonexistent", strict=False), [])
    np.testing.assert_array_equal(
        example_util.get_int64_feature(ex, "nonexistent", strict=False), [])

  def test_set_feature(self):
    ex = tf.train.Example()

    # Set bytes features.
    example_util.set_feature(ex, "a1_bytes", ["a", "b"])
    example_util.set_feature(ex, "a2_bytes", ["A", "B"], kind="bytes_list")
    example_util.set_bytes_feature(ex, "a3_bytes", ["x", "y"])
    np.testing.assert_array_equal(
        np.array(ex.features.feature["a1_bytes"].bytes_list.value).astype(str),
        ["a", "b"])
    np.testing.assert_array_equal(
        np.array(ex.features.feature["a2_bytes"].bytes_list.value).astype(str),
        ["A", "B"])
    np.testing.assert_array_equal(
        np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
        ["x", "y"])
    with self.assertRaises(ValueError):
      example_util.set_feature(ex, "a3_bytes", ["xxx"])  # Duplicate.

    # Set float features.
    example_util.set_feature(ex, "b1_float", [1.0, 2.0])
    example_util.set_feature(ex, "b2_float", [10.0, 20.0], kind="float_list")
    example_util.set_float_feature(ex, "b3_float", [88.0, 99.0])
    np.testing.assert_array_almost_equal(
        ex.features.feature["b1_float"].float_list.value, [1.0, 2.0])
    np.testing.assert_array_almost_equal(
        ex.features.feature["b2_float"].float_list.value, [10.0, 20.0])
    np.testing.assert_array_almost_equal(
        ex.features.feature["b3_float"].float_list.value, [88.0, 99.0])
    with self.assertRaises(ValueError):
      example_util.set_feature(ex, "b3_float", [1234.0])  # Duplicate.

    # Set int64 features.
    example_util.set_feature(ex, "c1_int64", [1, 2, 3])
    example_util.set_feature(ex, "c2_int64", [11, 22, 33], kind="int64_list")
    example_util.set_int64_feature(ex, "c3_int64", [88, 99])
    np.testing.assert_array_equal(
        ex.features.feature["c1_int64"].int64_list.value, [1, 2, 3])
    np.testing.assert_array_equal(
        ex.features.feature["c2_int64"].int64_list.value, [11, 22, 33])
    np.testing.assert_array_equal(
        ex.features.feature["c3_int64"].int64_list.value, [88, 99])
    with self.assertRaises(ValueError):
      example_util.set_feature(ex, "c3_int64", [1234])  # Duplicate.

    # Overwrite features.
    example_util.set_feature(ex, "a3_bytes", ["xxx"], allow_overwrite=True)
    np.testing.assert_array_equal(
        np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
        ["xxx"])
    example_util.set_feature(ex, "b3_float", [1234.0], allow_overwrite=True)
    np.testing.assert_array_almost_equal(
        ex.features.feature["b3_float"].float_list.value, [1234.0])
    example_util.set_feature(ex, "c3_int64", [1234], allow_overwrite=True)
    np.testing.assert_array_equal(
        ex.features.feature["c3_int64"].int64_list.value, [1234])


if __name__ == "__main__":
  tf.test.main()
```
research/astronet/astrowavenet/BUILD (deleted, 100644 → 0)

```python
"""A TensorFlow model for generative modeling of light curves."""

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_binary(
    name = "trainer",
    srcs = ["trainer.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":astrowavenet_model",
        ":configurations",
        "//astronet/util:config_util",
        "//astronet/util:configdict",
        "//astronet/util:estimator_runner",
        "//astrowavenet/data:kepler_light_curves",
        "//astrowavenet/data:synthetic_transits",
        "//astrowavenet/util:estimator_util",
    ],
)

py_library(
    name = "configurations",
    srcs = ["configurations.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "astrowavenet_model",
    srcs = ["astrowavenet_model.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "astrowavenet_model_test",
    size = "small",
    srcs = ["astrowavenet_model_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":astrowavenet_model",
        ":configurations",
        "//astronet/util:configdict",
    ],
)
```
research/astronet/astrowavenet/README.md (deleted, 100644 → 0)

# AstroWaveNet: A generative model for light curves.

Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499

## Code Authors

Alex Tamkin: [@atamkin](https://github.com/atamkin)

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Pull Requests / Issues

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Additional Dependencies

This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).

In addition to the dependencies listed in the top-level README, this package
requires:

* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))

## Basic Usage

To train a model on synthetic transits:

```bash
bazel build astrowavenet/...
```

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=synthetic_transits \
  --model_dir=/tmp/astrowavenet/ \
  --config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```
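Note that the `--config_overrides` value is plain JSON. A hedged sketch of how such an override can be parsed and nested, assuming the trainer delegates to the `config_util` helpers shown earlier (the trainer's actual flag handling is not shown in this commit):

```python
from astronet.util import config_util

# parse_json accepts either a literal JSON string or a path to a JSON file.
overrides = config_util.parse_json(
    '{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}')
assert overrides["hparams"]["batch_size"] == 16
```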
research/astronet/astrowavenet/__init__.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
research/astronet/astrowavenet/astrowavenet_model.py (deleted, 100644 → 0)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A TensorFlow WaveNet model for generative modeling of light curves.

Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_probability as tfp


def _shift_right(x):
  """Shifts the input Tensor right by one index along the second dimension.

  Pads the front with zeros and discards the last element.

  Args:
    x: Input three-dimensional tf.Tensor.

  Returns:
    Padded, shifted tensor of same shape as input.
  """
  x_padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
  return x_padded[:, :-1, :]


class AstroWaveNet(object):
  """A TensorFlow model for generative modeling of light curves."""

  def __init__(self, features, hparams, mode):
    """Basic setup.

    The actual TensorFlow graph is constructed in build().

    Args:
      features: A dictionary containing "autoregressive_input" and
        "conditioning_stack", each of which is a named input Tensor. All
        features have dtype float32 and shape [batch_size, length, dim].
      hparams: A ConfigDict of hyperparameters for building the model.
      mode: A tf.estimator.ModeKeys to specify whether the graph should be
        built for training, evaluation or prediction.

    Raises:
      ValueError: If mode is invalid.
    """
    valid_modes = [
        tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
        tf.estimator.ModeKeys.PREDICT
    ]
    if mode not in valid_modes:
      raise ValueError("Expected mode in {}. Got: {}".format(
          valid_modes, mode))

    self.hparams = hparams
    self.mode = mode

    self.autoregressive_input = features["autoregressive_input"]
    self.conditioning_stack = features["conditioning_stack"]
    self.weights = features.get("weights")

    self.network_output = None  # Sum of skip connections from dilation stack.
    self.dist_params = None  # Dict of predicted distribution parameters.
    self.predicted_distributions = None  # Predicted distribution for examples.
    self.autoregressive_target = None  # Autoregressive target predictions.
    self.batch_losses = None  # Loss for each predicted distribution in batch.
    self.per_example_loss = None  # Loss for each example in batch.
    self.num_nonzero_weight_examples = None  # Number of examples in batch.
    self.total_loss = None  # Overall loss for the batch.
    self.global_step = None  # Global step Tensor.

  def causal_conv_layer(self, x, output_size, kernel_width, dilation_rate=1):
    """Applies a dilated causal convolution to the input.

    Args:
      x: tf.Tensor; Input tensor.
      output_size: int; Number of output filters for the convolution.
      kernel_width: int; Width of the 1D convolution window.
      dilation_rate: int; Dilation rate of the layer.

    Returns:
      Resulting tf.Tensor after applying the convolution.
    """
    causal_conv_op = tf.keras.layers.Conv1D(
        output_size,
        kernel_width,
        padding="causal",
        dilation_rate=dilation_rate,
        name="causal_conv")
    return causal_conv_op(x)

  def conv_1x1_layer(self, x, output_size, activation=None):
    """Applies a 1x1 convolution to the input.

    Args:
      x: tf.Tensor; Input tensor.
      output_size: int; Number of output filters for the 1x1 convolution.
      activation: Activation function to apply (e.g. 'relu').

    Returns:
      Resulting tf.Tensor after applying the 1x1 convolution.
    """
    conv_1x1_op = tf.keras.layers.Conv1D(
        output_size, 1, activation=activation, name="conv1x1")
    return conv_1x1_op(x)

  def gated_residual_layer(self, x, dilation_rate):
    """Creates a gated, dilated convolutional layer with a residual connection.

    Args:
      x: tf.Tensor; Input tensor.
      dilation_rate: int; Dilation rate of the layer.

    Returns:
      skip_connection: tf.Tensor; Skip connection to network_output layer.
      residual_connection: tf.Tensor; Sum of learned residual and input tensor.
    """
    with tf.variable_scope("filter"):
      x_filter_conv = self.causal_conv_layer(
          x, x.shape[-1].value, self.hparams.dilation_kernel_width,
          dilation_rate)
      cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
                                             x.shape[-1].value)
    with tf.variable_scope("gate"):
      x_gate_conv = self.causal_conv_layer(
          x, x.shape[-1].value, self.hparams.dilation_kernel_width,
          dilation_rate)
      cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
                                           x.shape[-1].value)

    gated_activation = (
        tf.tanh(x_filter_conv + cond_filter_conv) *
        tf.sigmoid(x_gate_conv + cond_gate_conv))

    with tf.variable_scope("residual"):
      residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
    with tf.variable_scope("skip"):
      skip_connection = self.conv_1x1_layer(gated_activation,
                                            self.hparams.skip_output_dim)

    return skip_connection, x + residual

  def build_network(self):
    """Builds WaveNet network.

    This consists of:
      1) An initial causal convolution,
      2) The dilation stack, and
      3) Summing of skip connections

    The network output can then be used to predict various output
    distributions.

    Inputs:
      self.autoregressive_input
      self.conditioning_stack

    Outputs:
      self.network_output; tf.Tensor
    """
    skip_connections = []
    x = _shift_right(self.autoregressive_input)
    with tf.variable_scope("preprocess"):
      x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
                                 self.hparams.preprocess_kernel_width)
    for i in range(self.hparams.num_residual_blocks):
      with tf.variable_scope("block_{}".format(i)):
        for dilation_rate in self.hparams.dilation_rates:
          with tf.variable_scope("dilation_{}".format(dilation_rate)):
            skip_connection, x = self.gated_residual_layer(x, dilation_rate)
            skip_connections.append(skip_connection)

    self.network_output = tf.add_n(skip_connections)

  def dist_params_layer(self, x, outputs_size):
    """Converts x to the correct shape for populating a distribution object.

    Args:
      x: A Tensor of shape [batch_size, time_series_length, num_features].
      outputs_size: The number of parameters needed to specify all the
        distributions in the output. E.g. 5*3=15 to specify 5 distributions
        with 3 parameters each.

    Returns:
      The parameters of each distribution, a tensor of shape [batch_size,
      time_series_length, outputs_size].
    """
    with tf.variable_scope("dist_params"):
      conv_outputs = self.conv_1x1_layer(x, outputs_size)
    return conv_outputs

  def build_predictions(self):
    """Predicts output distribution from network outputs.

    Runs the model through:
      1) ReLU
      2) 1x1 convolution
      3) ReLU
      4) 1x1 convolution

    The result of the last convolution is used as the parameters of the
    specified output distribution (currently either Categorical or Normal).

    Inputs:
      self.network_output

    Outputs:
      self.dist_params
      self.predicted_distributions

    Raises:
      ValueError: If distribution type is neither 'categorical' nor 'normal'.
    """
    with tf.variable_scope("postprocess"):
      network_output = tf.keras.activations.relu(self.network_output)
      network_output = self.conv_1x1_layer(
          network_output,
          output_size=network_output.shape[-1].value,
          activation="relu")

    num_dists = self.autoregressive_input.shape[-1].value
    if self.hparams.output_distribution.type == "categorical":
      num_classes = self.hparams.output_distribution.num_classes
      logits = self.dist_params_layer(network_output, num_dists * num_classes)
      logits_shape = tf.concat(
          [tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
      logits = tf.reshape(logits, logits_shape)
      dist = tfp.distributions.Categorical(logits=logits)
      dist_params = {"logits": logits}
    elif self.hparams.output_distribution.type == "normal":
      loc_scale = self.dist_params_layer(network_output, num_dists * 2)
      loc, scale = tf.split(loc_scale, 2, axis=-1)
      # Ensure scale is positive.
      scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
      dist = tfp.distributions.Normal(loc, scale)
      dist_params = {"loc": loc, "scale": scale}
    else:
      raise ValueError("Unsupported distribution type {}".format(
          self.hparams.output_distribution.type))

    self.dist_params = dist_params
    self.predicted_distributions = dist

  def build_losses(self):
    """Builds the training losses.

    Inputs:
      self.predicted_distributions

    Outputs:
      self.batch_losses
      self.total_loss
    """
    autoregressive_target = self.autoregressive_input

    # Quantize the target if the output distribution is categorical.
    if self.hparams.output_distribution.type == "categorical":
      min_val = self.hparams.output_distribution.min_quantization_value
      max_val = self.hparams.output_distribution.max_quantization_value
      num_classes = self.hparams.output_distribution.num_classes
      clipped_target = tf.keras.backend.clip(autoregressive_target, min_val,
                                             max_val)
      quantized_target = tf.floor(
          (clipped_target - min_val) / (max_val - min_val) * num_classes)
      # Deal with the corner case where clipped_target equals max_val by
      # mapping the label num_classes to num_classes - 1. Essentially, this
      # makes the final quantized bucket a closed interval while all the
      # other quantized buckets are half-open intervals.
      quantized_target = tf.where(
          quantized_target >= num_classes,
          tf.ones_like(quantized_target) * (num_classes - 1),
          quantized_target)
      autoregressive_target = quantized_target

    log_prob = self.predicted_distributions.log_prob(autoregressive_target)

    weights = self.weights
    if weights is None:
      weights = tf.ones_like(log_prob)
    weights_dim = len(weights.shape)
    per_example_weight = tf.reduce_sum(
        weights, axis=list(range(1, weights_dim)))
    per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
    num_examples = tf.reduce_sum(per_example_indicator)

    batch_losses = -log_prob * weights
    losses_ndims = batch_losses.shape.ndims
    per_example_loss_sum = tf.reduce_sum(
        batch_losses, axis=list(range(1, losses_ndims)))
    per_example_loss = tf.where(per_example_weight > 0,
                                per_example_loss_sum / per_example_weight,
                                tf.zeros_like(per_example_weight))
    total_loss = tf.reduce_sum(per_example_loss) / num_examples

    self.autoregressive_target = autoregressive_target
    self.batch_losses = batch_losses
    self.per_example_loss = per_example_loss
    self.num_nonzero_weight_examples = num_examples
    self.total_loss = total_loss

  def build(self):
    """Creates all ops for training, evaluation or inference."""
    self.global_step = tf.train.get_or_create_global_step()
    self.build_network()
    self.build_predictions()
    if self.mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
      self.build_losses()
```
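To make the causality trick concrete, a minimal sketch of `_shift_right` (TF1 style, matching the file above; run inside the module or with the function copied in):

```python
import numpy as np
import tensorflow as tf

# _shift_right pads the front of the time axis and drops the last step, so
# the prediction at time t can only see inputs up to t-1.
x = tf.constant(np.arange(1, 7, dtype=np.float32).reshape(1, 3, 2))
shifted = _shift_right(x)
with tf.Session() as sess:
  print(sess.run(shifted))
  # [[[0. 0.]
  #   [1. 2.]
  #   [3. 4.]]]
```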
research/astronet/astrowavenet/astrowavenet_model_test.py
deleted
100644 → 0
View file @
93e0022d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for astrowavenet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from astronet.util import configdict
from astrowavenet import astrowavenet_model


class AstrowavenetTest(tf.test.TestCase):

  def assertShapeEquals(self, shape, tensor_or_array):
    """Asserts that a Tensor or Numpy array has the expected shape.

    Args:
      shape: Numpy array or anything that can be converted to one.
      tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
    """
    if isinstance(tensor_or_array, (np.ndarray, np.generic)):
      self.assertAllEqual(shape, tensor_or_array.shape)
    elif isinstance(tensor_or_array, (tf.Tensor, tf.Variable)):
      self.assertAllEqual(shape, tensor_or_array.shape.as_list())
    else:
      raise TypeError("tensor_or_array must be a Tensor or Numpy ndarray")

  def test_build_model(self):
    time_series_length = 9
    input_num_features = 8
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0.001,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    variables = {v.op.name: v for v in tf.trainable_variables()}

    # Verify variable shapes in two residual blocks.
    var = variables["preprocess/causal_conv/kernel"]
    self.assertShapeEquals((5, 8, 3), var)
    var = variables["preprocess/causal_conv/bias"]
    self.assertShapeEquals((3,), var)

    var = variables["block_0/dilation_1/filter/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_0/dilation_1/filter/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/filter/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_0/dilation_1/filter/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/gate/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_0/dilation_1/gate/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/gate/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_0/dilation_1/gate/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/residual/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 3), var)
    var = variables["block_0/dilation_1/residual/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_0/dilation_1/skip/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 6), var)
    var = variables["block_0/dilation_1/skip/conv1x1/bias"]
    self.assertShapeEquals((6,), var)

    var = variables["block_1/dilation_4/filter/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_1/dilation_4/filter/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/filter/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_1/dilation_4/filter/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/gate/causal_conv/kernel"]
    self.assertShapeEquals((2, 3, 3), var)
    var = variables["block_1/dilation_4/gate/causal_conv/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/gate/conv1x1/kernel"]
    self.assertShapeEquals((1, 7, 3), var)
    var = variables["block_1/dilation_4/gate/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/residual/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 3), var)
    var = variables["block_1/dilation_4/residual/conv1x1/bias"]
    self.assertShapeEquals((3,), var)
    var = variables["block_1/dilation_4/skip/conv1x1/kernel"]
    self.assertShapeEquals((1, 3, 6), var)
    var = variables["block_1/dilation_4/skip/conv1x1/bias"]
    self.assertShapeEquals((6,), var)

    var = variables["postprocess/conv1x1/kernel"]
    self.assertShapeEquals((1, 6, 6), var)
    var = variables["postprocess/conv1x1/bias"]
    self.assertShapeEquals((6,), var)
    var = variables["dist_params/conv1x1/kernel"]
    self.assertShapeEquals((1, 6, 16), var)
    var = variables["dist_params/conv1x1/bias"]
    self.assertShapeEquals((16,), var)

    # Verify total number of trainable parameters.
    num_preprocess_params = (
        hparams.preprocess_kernel_width * input_num_features *
        hparams.preprocess_output_size + hparams.preprocess_output_size)
    num_gated_params = (
        hparams.dilation_kernel_width * hparams.preprocess_output_size *
        hparams.preprocess_output_size + hparams.preprocess_output_size +
        1 * context_num_features * hparams.preprocess_output_size +
        hparams.preprocess_output_size) * 2
    num_residual_params = (
        1 * hparams.preprocess_output_size * hparams.preprocess_output_size +
        hparams.preprocess_output_size)
    num_skip_params = (
        1 * hparams.preprocess_output_size * hparams.skip_output_dim +
        hparams.skip_output_dim)
    num_block_params = (
        num_gated_params + num_residual_params + num_skip_params) * len(
            hparams.dilation_rates) * hparams.num_residual_blocks
    num_postprocess_params = (
        1 * hparams.skip_output_dim * hparams.skip_output_dim +
        hparams.skip_output_dim)
    num_dist_params = (1 * hparams.skip_output_dim * 2 * input_num_features +
                       2 * input_num_features)
    total_params = (
        num_preprocess_params + num_block_params + num_postprocess_params +
        num_dist_params)
    total_retrieved_params = 0
    for v in tf.trainable_variables():
      total_retrieved_params += np.prod(v.shape)
    self.assertEqual(total_params, total_retrieved_params)

    # Verify model runs and outputs losses of correct shape.
    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      batch_size = 11
      feed_dict = {
          input_placeholder:
              np.random.random(
                  (batch_size, time_series_length, input_num_features)),
          context_placeholder:
              np.random.random(
                  (batch_size, time_series_length, context_num_features))
      }
      batch_losses, per_example_loss, total_loss = sess.run(
          [model.batch_losses, model.per_example_loss, model.total_loss],
          feed_dict=feed_dict)
      self.assertShapeEquals(
          (batch_size, time_series_length, input_num_features), batch_losses)
      self.assertShapeEquals((batch_size,), per_example_loss)
      self.assertShapeEquals((), total_loss)

  def test_build_model_categorical(self):
    time_series_length = 9
    input_num_features = 8
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "categorical",
            "num_classes": 256,
            "min_quantization_value": -1,
            "max_quantization_value": 1
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    variables = {v.op.name: v for v in tf.trainable_variables()}
    var = variables["dist_params/conv1x1/kernel"]
    self.assertShapeEquals(
        (1, hparams.skip_output_dim,
         hparams.output_distribution.num_classes * input_num_features), var)
    var = variables["dist_params/conv1x1/bias"]
    self.assertShapeEquals(
        (hparams.output_distribution.num_classes * input_num_features,), var)

    # Verify model runs and outputs losses of correct shape.
    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      batch_size = 11
      feed_dict = {
          input_placeholder:
              np.random.random(
                  (batch_size, time_series_length, input_num_features)),
          context_placeholder:
              np.random.random(
                  (batch_size, time_series_length, context_num_features))
      }
      batch_losses, per_example_loss, total_loss = sess.run(
          [model.batch_losses, model.per_example_loss, model.total_loss],
          feed_dict=feed_dict)
      self.assertShapeEquals(
          (batch_size, time_series_length, input_num_features), batch_losses)
      self.assertShapeEquals((batch_size,), per_example_loss)
      self.assertShapeEquals((), total_loss)

  def test_output_normal(self):
    time_series_length = 6
    input_num_features = 2
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    # Model predicts loc and scale.
    self.assertItemsEqual(["loc", "scale"], model.dist_params.keys())
    self.assertShapeEquals((None, time_series_length, input_num_features),
                           model.dist_params["loc"])
    self.assertShapeEquals((None, time_series_length, input_num_features),
                           model.dist_params["scale"])

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
              [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["loc"]: [
              [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
              [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
          ],
          model.dist_params["scale"]: [
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
          ],
      }
      batch_losses, per_example_loss, num_examples, total_loss = sess.run(
          [
              model.batch_losses, model.per_example_loss,
              model.num_nonzero_weight_examples, model.total_loss
          ],
          feed_dict=feed_dict)
      np.testing.assert_array_almost_equal(
          [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
           [[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]]], batch_losses)
      np.testing.assert_array_almost_equal([5.96392435, 5.96392435],
                                           per_example_loss)
      np.testing.assert_almost_equal(2, num_examples)
      np.testing.assert_almost_equal(5.96392435, total_loss)

  def test_output_categorical(self):
    time_series_length = 3
    input_num_features = 1
    context_num_features = 7
    num_classes = 4  # For quantized categorical output predictions.

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "categorical",
            "min_scale": 0,
            "num_classes": num_classes,
            "min_quantization_value": 0,
            "max_quantization_value": 1
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    self.assertItemsEqual(["logits"], model.dist_params.keys())
    self.assertShapeEquals(
        (None, time_series_length, input_num_features, num_classes),
        model.dist_params["logits"])

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[0], [0], [0]],  # min_quantization_value
              [[0.2], [0.2], [0.2]],  # Within bucket.
              [[0.25], [0.25], [0.25]],  # On bucket boundary.
              [[0.5], [0.5], [0.5]],  # On bucket boundary.
              [[0.8], [0.8], [0.8]],  # Within bucket.
              [[1], [1], [1]],  # max_quantization_value
              [[-0.1], [1.5], [200]],  # Outside range: will be clipped.
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["logits"]: [
              [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[1, 0, 0, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[0, 1, 0, 0]], [[1, 0, 0, 0]], [[0, 0, 1, 0]]],
              [[[0, 0, 1, 0]], [[0, 1, 0, 0]], [[0, 0, 0, 1]]],
              [[[0, 0, 0, 1]], [[1, 0, 0, 0]], [[1, 0, 0, 0]]],
              [[[0, 0, 0, 1]], [[0, 1, 0, 0]], [[0, 0, 1, 0]]],
              [[[1, 0, 0, 0]], [[0, 0, 1, 0]], [[0, 1, 0, 0]]],
          ],
      }
      (target, batch_losses, per_example_loss, num_examples,
       total_loss) = sess.run(
           [
               model.autoregressive_target, model.batch_losses,
               model.per_example_loss, model.num_nonzero_weight_examples,
               model.total_loss
           ],
           feed_dict=feed_dict)
      np.testing.assert_array_almost_equal([
          [[0], [0], [0]],
          [[0], [0], [0]],
          [[1], [1], [1]],
          [[2], [2], [2]],
          [[3], [3], [3]],
          [[3], [3], [3]],
          [[0], [3], [3]],
      ], target)
      np.testing.assert_array_almost_equal([
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
          [[0.74366838], [1.74366838], [1.74366838]],
      ], batch_losses)
      np.testing.assert_array_almost_equal([
          1.41033504, 1.41033504, 1.41033504, 1.41033504, 1.41033504,
          1.41033504, 1.41033504
      ], per_example_loss)
      np.testing.assert_almost_equal(7, num_examples)
      np.testing.assert_almost_equal(1.41033504, total_loss)

  def test_output_weighted(self):
    time_series_length = 6
    input_num_features = 2
    context_num_features = 7

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    weights_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "weights": weights_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 2,
        "skip_output_dim": 6,
        "preprocess_output_size": 3,
        "preprocess_kernel_width": 5,
        "num_residual_blocks": 2,
        "dilation_rates": [1, 2, 4],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9]],
              [[2, 8], [2, 8], [2, 8], [2, 8], [2, 8], [2, 8]],
              [[3, 7], [3, 7], [3, 7], [3, 7], [3, 7], [3, 7]],
          ],
          weights_placeholder: [
              [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
              [[1, 0], [1, 1], [1, 1], [0, 1], [0, 1], [0, 0]],
              [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
          ],
          # Context is not needed since we explicitly feed the dist params.
          model.dist_params["loc"]: [
              [[1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8]],
              [[2, 9], [2, 9], [2, 9], [2, 9], [2, 9], [2, 9]],
              [[3, 6], [3, 6], [3, 6], [3, 6], [3, 6], [3, 6]],
          ],
          model.dist_params["scale"]: [
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
              [[0.1, 0.1], [0.2, 0.2], [0.5, 0.5], [1, 1], [2, 2], [5, 5]],
          ],
      }
      batch_losses, per_example_loss, num_examples, total_loss = sess.run(
          [
              model.batch_losses, model.per_example_loss,
              model.num_nonzero_weight_examples, model.total_loss
          ],
          feed_dict=feed_dict)
      np.testing.assert_array_almost_equal(
          [[[-1.38364656, 48.61635344], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0.91893853, 1.41893853],
            [1.61208571, 1.73708571], [2.52837645, 2.54837645]],
           [[-1.38364656, 0], [-0.69049938, 11.80950062],
            [0.22579135, 2.22579135], [0, 1.41893853], [0, 1.73708571],
            [0, 0]],
           [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]], batch_losses)
      np.testing.assert_array_almost_equal([5.96392435, 2.19185166, 0],
                                           per_example_loss)
      np.testing.assert_almost_equal(2, num_examples)
      np.testing.assert_almost_equal(4.07788801, total_loss)

  def test_causality(self):
    time_series_length = 7
    input_num_features = 1
    context_num_features = 1

    input_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, input_num_features],
        name="input")
    context_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, time_series_length, context_num_features],
        name="context")
    features = {
        "autoregressive_input": input_placeholder,
        "conditioning_stack": context_placeholder
    }
    mode = tf.estimator.ModeKeys.TRAIN
    hparams = configdict.ConfigDict({
        "dilation_kernel_width": 1,
        "skip_output_dim": 1,
        "preprocess_output_size": 1,
        "preprocess_kernel_width": 1,
        "num_residual_blocks": 1,
        "dilation_rates": [1],
        "output_distribution": {
            "type": "normal",
            "min_scale": 0.001,
        }
    })

    model = astrowavenet_model.AstroWaveNet(features, hparams, mode)
    model.build()

    scaffold = tf.train.Scaffold()
    scaffold.finalize()
    with self.cached_session() as sess:
      sess.run([scaffold.init_op, scaffold.local_init_op])
      step = sess.run(model.global_step)
      self.assertEqual(0, step)
      feed_dict = {
          input_placeholder: [
              [[0], [0], [0], [0], [0], [0], [0]],
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
          ],
          context_placeholder: [
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
          ],
      }
      network_output = sess.run(model.network_output, feed_dict=feed_dict)
      np.testing.assert_array_equal(
          [
              [[0], [0], [0], [0], [0], [0], [0]],
              # Input elements are used to predict the next timestamp.
              [[0], [1], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [0], [1], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [0]],
              # Context elements are used to predict the current timestamp.
              [[1], [0], [0], [0], [0], [0], [0]],
              [[0], [0], [0], [1], [0], [0], [0]],
              [[0], [0], [0], [0], [0], [0], [1]],
          ],
          np.greater(np.abs(network_output), 0))


if __name__ == "__main__":
  tf.test.main()
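The expected values in test_output_normal follow directly from the Normal log-density; a short NumPy sketch to reproduce them by hand (the helper name `normal_nll` is ours, not part of the test):

import numpy as np

def normal_nll(x, loc, scale):
  # -log N(x; loc, scale), i.e. the per-element batch loss above.
  return (0.5 * np.log(2 * np.pi) + np.log(scale) +
          (x - loc) ** 2 / (2 * scale ** 2))

print(normal_nll(1, 1, 0.1))  # -1.38364656: density > 1, so the NLL is negative.
print(normal_nll(9, 8, 0.1))  # 48.61635344: one unit off at scale 0.1.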
research/astronet/astrowavenet/configurations.py deleted 100644 → 0 View file @ 93e0022d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for model building, training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def base():
  """Returns the base config for model building, training and evaluation."""
  return {
      # Hyperparameters for building and training the model.
      "hparams": {
          "batch_size": 64,

          "dilation_kernel_width": 2,
          "skip_output_dim": 10,
          "preprocess_output_size": 3,
          "preprocess_kernel_width": 10,
          "num_residual_blocks": 4,
          "dilation_rates": [1, 2, 4, 8, 16],
          "output_distribution": {
              "type": "normal",
              "min_scale": 0.001
          },

          # Learning rate parameters.
          "learning_rate": 1e-6,
          "learning_rate_decay_steps": 0,
          "learning_rate_decay_factor": 0,
          "learning_rate_decay_staircase": True,

          # Optimizer for training the model.
          "optimizer": "adam",

          # If not None, gradient norms will be clipped to this value.
          "clip_gradient_norm": 1,
      }
  }


def categorical():
  """Returns a config for models with a categorical output distribution.

  Input values will be clipped to {min,max}_quantization_value, then linearly
  split into num_classes buckets.
  """
  config = base()
  config["hparams"]["output_distribution"] = {
      "type": "categorical",
      "num_classes": 256,
      "min_quantization_value": -1,
      "max_quantization_value": 1
  }
  return config


def get_config(config_name):
  """Returns the config corresponding to the provided name."""
  if config_name in ["base", "normal"]:
    return base()
  elif config_name == "categorical":
    return categorical()
  else:
    raise ValueError("Unrecognized config name: {}".format(config_name))
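A minimal sketch of how these configs are typically consumed, assuming the ConfigDict wrapper from astronet.util used throughout the tests above:

from astronet.util import configdict
from astrowavenet import configurations

config = configdict.ConfigDict(configurations.get_config("categorical"))
config.hparams.batch_size = 32  # Override a default before training.
print(config.hparams.output_distribution.num_classes)  # 256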
research/astronet/astrowavenet/data/BUILD deleted 100644 → 0 View file @ 93e0022d
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "base",
    srcs = [
        "base.py",
    ],
    deps = [
        "//astronet/ops:dataset_ops",
        "//astronet/util:configdict",
    ],
)

py_test(
    name = "base_test",
    srcs = ["base_test.py"],
    data = ["test_data/test-dataset.tfrecord"],
    srcs_version = "PY2AND3",
    deps = [":base"],
)

py_library(
    name = "kepler_light_curves",
    srcs = [
        "kepler_light_curves.py",
    ],
    deps = [
        ":base",
        "//astronet/util:configdict",
    ],
)

py_library(
    name = "synthetic_transits",
    srcs = [
        "synthetic_transits.py",
    ],
    deps = [
        ":base",
        ":synthetic_transit_maker",
        "//astronet/util:configdict",
    ],
)

py_library(
    name = "synthetic_transit_maker",
    srcs = [
        "synthetic_transit_maker.py",
    ],
)

py_test(
    name = "synthetic_transit_maker_test",
    srcs = ["synthetic_transit_maker_test.py"],
    srcs_version = "PY2AND3",
    deps = [":synthetic_transit_maker"],
)
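Given this BUILD file, the test targets can be exercised in the usual Bazel way; assuming the astronet directory is the workspace root, the invocations would look like:

bazel test //astrowavenet/data:base_test
bazel test //astrowavenet/data:synthetic_transit_maker_test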
research/astronet/astrowavenet/data/__init__.py deleted 100644 → 0 View file @ 93e0022d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
research/astronet/astrowavenet/data/base.py deleted 100644 → 0 View file @ 93e0022d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base dataset builder classes for AstroWaveNet input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six
import tensorflow as tf

from astronet.util import configdict
from astronet.ops import dataset_ops


@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
  """Base class for building a dataset input pipeline for AstroWaveNet."""

  def __init__(self, config_overrides=None):
    """Initializes the dataset builder.

    Args:
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
    """
    self.config = configdict.ConfigDict(self.default_config())
    if config_overrides is not None:
      self.config.update(config_overrides)

  @staticmethod
  def default_config():
    """Returns the default configuration as a ConfigDict or Python dict."""
    return {}

  @abc.abstractmethod
  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size: The number of input examples in each batch.

    Returns:
      A tf.data.Dataset object.
    """
    raise NotImplementedError


@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
  """Abstract base class for a dataset consisting of sharded files."""

  def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
    """Initializes the dataset builder.

    Args:
      file_pattern: File pattern matching input file shards, e.g.
        "/tmp/train-?????-of-00100". May also be a comma-separated list of file
        patterns.
      mode: A tf.estimator.ModeKeys.
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
      use_tpu: Whether to build the dataset for TPU.
    """
    super(_ShardedDatasetBuilder, self).__init__(config_overrides)
    self.file_pattern = file_pattern
    self.mode = mode
    self.use_tpu = use_tpu

  @staticmethod
  def default_config():
    config = super(_ShardedDatasetBuilder,
                   _ShardedDatasetBuilder).default_config()
    config.update({
        "max_length": 1024,
        "shuffle_values_buffer": 1000,
        "num_parallel_parser_calls": 4,
        "batches_buffer_size": None,  # Defaults to max(1, 256 / batch_size).
    })
    return config

  @abc.abstractmethod
  def file_reader(self):
    """Returns a function that reads a single sharded file."""
    raise NotImplementedError

  @abc.abstractmethod
  def create_example_parser(self):
    """Returns a function that parses a single tf.Example proto."""
    raise NotImplementedError

  def _batch_and_pad(self, dataset, batch_size):
    """Combines elements into batches of the same length, padding if needed."""
    if self.use_tpu:
      padded_length = self.config.max_length
      if not padded_length:
        raise ValueError("config.max_length is required when using TPU")
      # Pad with zeros up to padded_length. Note that this will pad the
      # "weights" Tensor with zeros as well, which ensures that padded elements
      # do not contribute to the loss.
      padded_shapes = {}
      for name, shape in six.iteritems(dataset.output_shapes):
        shape.assert_is_compatible_with([None, None])  # Expect a 2D sequence.
        dims = shape.as_list()
        dims[0] = padded_length
        shape = tf.TensorShape(dims)
        shape.assert_is_fully_defined()
        padded_shapes[name] = shape
    else:
      # Pad each batch up to the maximum size of each dimension in the batch.
      padded_shapes = dataset.output_shapes
    return dataset.padded_batch(batch_size, padded_shapes)

  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size: The number of input examples in each batch.

    Returns:
      A tf.data.Dataset.

    Raises:
      ValueError: If no files match self.file_pattern.
    """
    file_patterns = self.file_pattern.split(",")
    filenames = []
    for p in file_patterns:
      matches = tf.gfile.Glob(p)
      if not matches:
        raise ValueError("Found no input files matching {}".format(p))
      filenames.extend(matches)
    tf.logging.info(
        "Building input pipeline from %d files matching patterns: %s",
        len(filenames), file_patterns)

    is_training = self.mode == tf.estimator.ModeKeys.TRAIN

    # Create a string dataset of filenames, and possibly shuffle.
    filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if is_training and len(filenames) > 1:
      filename_dataset = filename_dataset.shuffle(len(filenames))

    # Read serialized Example protos.
    dataset = filename_dataset.apply(
        tf.contrib.data.parallel_interleave(
            self.file_reader(), cycle_length=8, block_length=8, sloppy=True))

    if is_training:
      # Shuffle and repeat. Note that shuffle() is before repeat(), so elements
      # are shuffled among each epoch of data, and not between epochs of data.
      if self.config.shuffle_values_buffer > 0:
        dataset = dataset.shuffle(self.config.shuffle_values_buffer)
      dataset = dataset.repeat()

    # Map the parser over the dataset.
    dataset = dataset.map(
        self.create_example_parser(),
        num_parallel_calls=self.config.num_parallel_parser_calls)

    def _prepare_wavenet_inputs(features):
      """Validates features, and clips lengths and adds weights if needed."""
      # Validate feature names.
      required_features = {"autoregressive_input", "conditioning_stack"}
      allowed_features = required_features | {"weights"}
      feature_names = features.keys()
      if not required_features.issubset(feature_names):
        raise ValueError("Features must contain all of: {}. Got: {}".format(
            required_features, feature_names))
      if not allowed_features.issuperset(feature_names):
        raise ValueError("Features can only contain: {}. Got: {}".format(
            allowed_features, feature_names))

      output = {}
      for name, value in features.items():
        # Validate shapes. The output dimension is [num_samples, dim].
        ndims = len(value.shape)
        if ndims == 1:
          # Add an extra dimension: [num_samples] -> [num_samples, 1].
          value = tf.expand_dims(value, -1)
        elif ndims != 2:
          raise ValueError(
              "Features should be 1D or 2D sequences. Got '{}' = {}".format(
                  name, value))
        if self.config.max_length:
          value = value[:self.config.max_length]
        output[name] = value

      if "weights" not in output:
        output["weights"] = tf.ones_like(output["autoregressive_input"])

      return output

    dataset = dataset.map(_prepare_wavenet_inputs)

    # Batch results by up to batch_size.
    dataset = self._batch_and_pad(dataset, batch_size)

    if is_training:
      # The dataset repeats infinitely before batching, so each batch has the
      # maximum number of elements.
      dataset = dataset_ops.set_batch_size(dataset, batch_size)
    elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
      # Pad to ensure that each batch has the same number of elements.
      dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)

    # Prefetch batches.
    buffer_size = (self.config.batches_buffer_size or
                   max(1, int(256 / batch_size)))
    dataset = dataset.prefetch(buffer_size)

    return dataset


def tfrecord_reader(filename):
  """Returns a tf.data.Dataset that reads a single TFRecord file shard."""
  return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)


class TFRecordDataset(_ShardedDatasetBuilder):
  """Builder for a dataset consisting of TFRecord files."""

  def file_reader(self):
    """Returns a function that reads a single file shard."""
    return tfrecord_reader
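A minimal sketch of what a concrete subclass supplies, in the spirit of the test subclass below; the feature keys "flux" and "context" here are hypothetical, not part of this module:

import tensorflow as tf

class MyLightCurves(TFRecordDataset):

  def create_example_parser(self):

    def _parser(serialized_example):
      # Parse two variable-length float features from each tf.Example.
      features = tf.parse_single_example(
          serialized_example,
          features={
              "flux": tf.VarLenFeature(tf.float32),  # Hypothetical key.
              "context": tf.VarLenFeature(tf.float32),  # Hypothetical key.
          })
      # base.py expands 1D sequences to [num_samples, 1] and adds "weights".
      return {
          "autoregressive_input": features["flux"].values,
          "conditioning_stack": features["context"].values,
      }

    return _parser

# Usage sketch:
# dataset = MyLightCurves("/tmp/train-*", tf.estimator.ModeKeys.TRAIN).build(16)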
research/astronet/astrowavenet/data/base_test.py deleted 100644 → 0 View file @ 93e0022d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for base.py."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os.path
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
from
astrowavenet.data
import
base
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
"test_srcdir"
,
""
,
"Test source directory."
)
_TEST_TFRECORD_FILE
=
"astrowavenet/data/test_data/test-dataset.tfrecord"
class
TFRecordDataset
(
base
.
TFRecordDataset
):
"""Concrete subclass of TFRecordDataset for testing."""
@
staticmethod
def
default_config
():
config
=
super
(
TFRecordDataset
,
TFRecordDataset
).
default_config
()
config
.
update
({
"shuffle_values_buffer"
:
0
,
# Ensure deterministic output.
"input_dim"
:
1
,
"conditioning_dim"
:
1
,
"include_weights"
:
False
,
})
return
config
def
create_example_parser
(
self
):
"""Returns a function that parses a single tf.Example proto."""
def
_example_parser
(
serialized_example
):
"""Parses a single tf.Example into feature and label Tensors."""
features
=
tf
.
parse_single_example
(
serialized_example
,
features
=
{
"feature_1"
:
tf
.
VarLenFeature
(
tf
.
float32
),
"feature_2"
:
tf
.
VarLenFeature
(
tf
.
float32
),
"feature_3"
:
tf
.
VarLenFeature
(
tf
.
float32
),
"feature_4"
:
tf
.
VarLenFeature
(
tf
.
float32
),
"weights"
:
tf
.
VarLenFeature
(
tf
.
float32
),
})
output
=
{}
if
self
.
config
.
input_dim
==
1
:
# Shape = [num_samples].
output
[
"autoregressive_input"
]
=
features
[
"feature_1"
].
values
elif
self
.
config
.
input_dim
==
2
:
# Shape = [num_samples, 2].
output
[
"autoregressive_input"
]
=
tf
.
stack
(
[
features
[
"feature_1"
].
values
,
features
[
"feature_2"
].
values
],
axis
=-
1
)
else
:
raise
ValueError
(
"Unexpected input_dim: {}"
.
format
(
self
.
config
.
input_dim
))
if
self
.
config
.
conditioning_dim
==
1
:
# Shape = [num_samples].
output
[
"conditioning_stack"
]
=
features
[
"feature_3"
].
values
elif
self
.
config
.
conditioning_dim
==
2
:
# Shape = [num_samples, 2].
output
[
"conditioning_stack"
]
=
tf
.
stack
(
[
features
[
"feature_3"
].
values
,
features
[
"feature_4"
].
values
],
axis
=-
1
)
else
:
raise
ValueError
(
"Unexpected conditioning_dim: {}"
.
format
(
self
.
config
.
conditioning_dim
))
if
self
.
config
.
include_weights
:
output
[
"weights"
]
=
features
[
"weights"
].
values
return
output
return
_example_parser
class
TFRecordDatasetTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
TFRecordDatasetTest
,
self
).
setUp
()
# The test dataset contains 8 tensorflow.Example protocol buffers. The i-th
# Example contains the following features:
# feature_1 = range(10, 10 + i + 1)
# feature_2 = range(20, 20 + i + 1)
# feature_3 = range(30, 30 + i + 1)
# feature_4 = range(40, 40 + i + 1)
# weights = [0] * i + [1]
self
.
_file_pattern
=
os
.
path
.
join
(
FLAGS
.
test_srcdir
,
_TEST_TFRECORD_FILE
)
def
testTrainMode
(
self
):
builder
=
TFRecordDataset
(
self
.
_file_pattern
,
tf
.
estimator
.
ModeKeys
.
TRAIN
)
next_features
=
builder
.
build
(
5
).
make_one_shot_iterator
().
get_next
()
self
.
assertItemsEqual
(
[
"autoregressive_input"
,
"conditioning_stack"
,
"weights"
],
next_features
.
keys
())
# Features have dynamic length but fixed batch size and input dimension.
next_features
[
"autoregressive_input"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"conditioning_stack"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"weights"
].
shape
.
assert_is_compatible_with
([
5
,
1
,
None
])
# Dataset repeats indefinitely.
with
self
.
test_session
()
as
sess
:
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
17
]],
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
37
]],
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
1
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
],
features
[
"weights"
])
def
testTrainModeReadWeights
(
self
):
config_overrides
=
{
"include_weights"
:
True
}
builder
=
TFRecordDataset
(
self
.
_file_pattern
,
tf
.
estimator
.
ModeKeys
.
TRAIN
,
config_overrides
=
config_overrides
)
next_features
=
builder
.
build
(
5
).
make_one_shot_iterator
().
get_next
()
self
.
assertItemsEqual
(
[
"autoregressive_input"
,
"conditioning_stack"
,
"weights"
],
next_features
.
keys
())
# Features have dynamic length but fixed batch size and input dimension.
next_features
[
"autoregressive_input"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"conditioning_stack"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"weights"
].
shape
.
assert_is_compatible_with
([
5
,
None
,
1
])
# Dataset repeats indefinitely.
with
self
.
test_session
()
as
sess
:
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
0
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
1
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
1
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
1
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
17
]],
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
37
]],
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
1
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
1
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
1
]],
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
0
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
0
],
[
0
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
1
],
[
0
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
1
],
[
0
]],
[[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
1
]],
],
features
[
"weights"
])
def
testTrainMode2DInput
(
self
):
config_overrides
=
{
"input_dim"
:
2
}
builder
=
TFRecordDataset
(
self
.
_file_pattern
,
tf
.
estimator
.
ModeKeys
.
TRAIN
,
config_overrides
=
config_overrides
)
next_features
=
builder
.
build
(
5
).
make_one_shot_iterator
().
get_next
()
self
.
assertItemsEqual
(
[
"autoregressive_input"
,
"conditioning_stack"
,
"weights"
],
next_features
.
keys
())
# Features have dynamic length but fixed batch size and input dimension.
next_features
[
"autoregressive_input"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
2
])
next_features
[
"conditioning_stack"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"weights"
].
shape
.
assert_is_compatible_with
([
5
,
1
,
None
])
# Dataset repeats indefinitely.
with
self
.
test_session
()
as
sess
:
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
,
20
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
15
,
25
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
15
,
25
],
[
16
,
26
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
15
,
25
],
[
16
,
26
],
[
17
,
27
]],
[[
10
,
20
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
37
]],
[[
30
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
]],
[[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
0
,
0
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
15
,
25
],
[
0
,
0
]],
[[
10
,
20
],
[
11
,
21
],
[
12
,
22
],
[
13
,
23
],
[
14
,
24
],
[
15
,
25
],
[
16
,
26
]
],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
],
[
31
],
[
32
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
0
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
0
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
0
]],
[[
30
],
[
31
],
[
32
],
[
33
],
[
34
],
[
35
],
[
36
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
[[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
],
[
1
,
1
]],
],
features
[
"weights"
])
def
testTrainMode2DConditioning
(
self
):
config_overrides
=
{
"conditioning_dim"
:
2
}
builder
=
TFRecordDataset
(
self
.
_file_pattern
,
tf
.
estimator
.
ModeKeys
.
TRAIN
,
config_overrides
=
config_overrides
)
next_features
=
builder
.
build
(
5
).
make_one_shot_iterator
().
get_next
()
self
.
assertItemsEqual
(
[
"autoregressive_input"
,
"conditioning_stack"
,
"weights"
],
next_features
.
keys
())
# Features have dynamic length but fixed batch size and input dimension.
next_features
[
"autoregressive_input"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"conditioning_stack"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
2
])
next_features
[
"weights"
].
shape
.
assert_is_compatible_with
([
5
,
1
,
None
])
# Dataset repeats indefinitely.
with
self
.
test_session
()
as
sess
:
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
,
40
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
],
[
17
]],
[[
10
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
35
,
45
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
35
,
45
],
[
36
,
46
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
35
,
45
],
[
36
,
46
],
[
37
,
47
]],
[[
30
,
40
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
[[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
],
[
0
]],
],
features
[
"weights"
])
features
=
sess
.
run
(
next_features
)
np
.
testing
.
assert_almost_equal
([
[[
10
],
[
11
],
[
12
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
0
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
0
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
0
]],
[[
10
],
[
11
],
[
12
],
[
13
],
[
14
],
[
15
],
[
16
]],
],
features
[
"autoregressive_input"
])
np
.
testing
.
assert_almost_equal
([
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
0
,
0
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
0
,
0
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
35
,
45
],
[
0
,
0
]],
[[
30
,
40
],
[
31
,
41
],
[
32
,
42
],
[
33
,
43
],
[
34
,
44
],
[
35
,
45
],
[
36
,
46
]
],
],
features
[
"conditioning_stack"
])
np
.
testing
.
assert_almost_equal
([
[[
1
],
[
1
],
[
1
],
[
0
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
0
]],
[[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
],
[
1
]],
],
features
[
"weights"
])
def
testTrainModeMaxLength
(
self
):
config_overrides
=
{
"max_length"
:
6
}
builder
=
TFRecordDataset
(
self
.
_file_pattern
,
tf
.
estimator
.
ModeKeys
.
TRAIN
,
config_overrides
=
config_overrides
)
next_features
=
builder
.
build
(
5
).
make_one_shot_iterator
().
get_next
()
self
.
assertItemsEqual
(
[
"autoregressive_input"
,
"conditioning_stack"
,
"weights"
],
next_features
.
keys
())
# Features have dynamic length but fixed batch size and input dimension.
next_features
[
"autoregressive_input"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"conditioning_stack"
].
shape
.
assert_is_compatible_with
(
[
5
,
None
,
1
])
next_features
[
"weights"
].
shape
.
assert_is_compatible_with
([
5
,
1
,
None
])
    # Dataset repeats indefinitely.
    with self.test_session() as sess:
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0]],
          [[10], [11], [12], [0], [0]],
          [[10], [11], [12], [13], [0]],
          [[10], [11], [12], [13], [14]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0]],
          [[30], [31], [32], [0], [0]],
          [[30], [31], [32], [33], [0]],
          [[30], [31], [32], [33], [34]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0]],
          [[1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [0]],
          [[1], [1], [1], [1], [1]],
      ], features["weights"])
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [0], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0], [0]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [0], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0], [0]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [0], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0], [0]],
      ], features["weights"])
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [0], [0], [0]],
          [[10], [11], [12], [13], [0], [0]],
          [[10], [11], [12], [13], [14], [0]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [0], [0], [0]],
          [[30], [31], [32], [33], [0], [0]],
          [[30], [31], [32], [33], [34], [0]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [0], [0], [0]],
          [[1], [1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [1], [0]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
      ], features["weights"])
  def testTrainModeTPU(self):
    config_overrides = {"max_length": 6}
    builder = TFRecordDataset(
        self._file_pattern,
        tf.estimator.ModeKeys.TRAIN,
        config_overrides=config_overrides,
        use_tpu=True)
    next_features = builder.build(5).make_one_shot_iterator().get_next()
    self.assertItemsEqual(
        ["autoregressive_input", "conditioning_stack", "weights"],
        next_features.keys())
    # Features have fixed shape.
    self.assertEqual([5, 6, 1], next_features["autoregressive_input"].shape)
    self.assertEqual([5, 6, 1], next_features["conditioning_stack"].shape)
    self.assertEqual([5, 6, 1], next_features["weights"].shape)
    # Dataset repeats indefinitely.
    with self.test_session() as sess:
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [0], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0], [0]],
          [[10], [11], [12], [0], [0], [0]],
          [[10], [11], [12], [13], [0], [0]],
          [[10], [11], [12], [13], [14], [0]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [0], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0], [0]],
          [[30], [31], [32], [0], [0], [0]],
          [[30], [31], [32], [33], [0], [0]],
          [[30], [31], [32], [33], [34], [0]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [0], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0], [0]],
          [[1], [1], [1], [0], [0], [0]],
          [[1], [1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [1], [0]],
      ], features["weights"])
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [0], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0], [0]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [0], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0], [0]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [0], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0], [0]],
      ], features["weights"])
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [0], [0], [0]],
          [[10], [11], [12], [13], [0], [0]],
          [[10], [11], [12], [13], [14], [0]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [0], [0], [0]],
          [[30], [31], [32], [33], [0], [0]],
          [[30], [31], [32], [33], [34], [0]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [0], [0], [0]],
          [[1], [1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [1], [0]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
      ], features["weights"])
  def testEvalMode(self):
    builder = TFRecordDataset(self._file_pattern, tf.estimator.ModeKeys.EVAL)
    next_features = builder.build(5).make_one_shot_iterator().get_next()
    self.assertItemsEqual(
        ["autoregressive_input", "conditioning_stack", "weights"],
        next_features.keys())
    # Features have dynamic length but fixed batch size and input dimension.
    next_features["autoregressive_input"].shape.assert_is_compatible_with(
        [5, None, 1])
    next_features["conditioning_stack"].shape.assert_is_compatible_with(
        [5, None, 1])
    next_features["weights"].shape.assert_is_compatible_with([5, None, 1])
    with self.test_session() as sess:
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0]],
          [[10], [11], [12], [0], [0]],
          [[10], [11], [12], [13], [0]],
          [[10], [11], [12], [13], [14]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0]],
          [[30], [31], [32], [0], [0]],
          [[30], [31], [32], [33], [0]],
          [[30], [31], [32], [33], [34]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0]],
          [[1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [0]],
          [[1], [1], [1], [1], [1]],
      ], features["weights"])
      # Partial batch.
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [13], [14], [15], [0], [0]],
          [[10], [11], [12], [13], [14], [15], [16], [0]],
          [[10], [11], [12], [13], [14], [15], [16], [17]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [33], [34], [35], [0], [0]],
          [[30], [31], [32], [33], [34], [35], [36], [0]],
          [[30], [31], [32], [33], [34], [35], [36], [37]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [1], [1], [1], [0]],
          [[1], [1], [1], [1], [1], [1], [1], [1]],
      ], features["weights"])
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_features)
  def testEvalModeTPU(self):
    config_overrides = {"max_length": 6}
    builder = TFRecordDataset(
        self._file_pattern,
        tf.estimator.ModeKeys.EVAL,
        config_overrides=config_overrides,
        use_tpu=True)
    next_features = builder.build(5).make_one_shot_iterator().get_next()
    self.assertItemsEqual(
        ["autoregressive_input", "conditioning_stack", "weights"],
        next_features.keys())
    # Features have fixed shape.
    self.assertEqual([5, 6, 1], next_features["autoregressive_input"].shape)
    self.assertEqual([5, 6, 1], next_features["conditioning_stack"].shape)
    self.assertEqual([5, 6, 1], next_features["weights"].shape)
    with self.test_session() as sess:
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [0], [0], [0], [0], [0]],
          [[10], [11], [0], [0], [0], [0]],
          [[10], [11], [12], [0], [0], [0]],
          [[10], [11], [12], [13], [0], [0]],
          [[10], [11], [12], [13], [14], [0]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [0], [0], [0], [0], [0]],
          [[30], [31], [0], [0], [0], [0]],
          [[30], [31], [32], [0], [0], [0]],
          [[30], [31], [32], [33], [0], [0]],
          [[30], [31], [32], [33], [34], [0]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [0], [0], [0], [0], [0]],
          [[1], [1], [0], [0], [0], [0]],
          [[1], [1], [1], [0], [0], [0]],
          [[1], [1], [1], [1], [0], [0]],
          [[1], [1], [1], [1], [1], [0]],
      ], features["weights"])
      # Partial batch, padded.
      features = sess.run(next_features)
      np.testing.assert_almost_equal([
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[10], [11], [12], [13], [14], [15]],
          [[0], [0], [0], [0], [0], [0]],
          [[0], [0], [0], [0], [0], [0]],
      ], features["autoregressive_input"])
      np.testing.assert_almost_equal([
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[30], [31], [32], [33], [34], [35]],
          [[0], [0], [0], [0], [0], [0]],
          [[0], [0], [0], [0], [0], [0]],
      ], features["conditioning_stack"])
      np.testing.assert_almost_equal([
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[1], [1], [1], [1], [1], [1]],
          [[0], [0], [0], [0], [0], [0]],
          [[0], [0], [0], [0], [0], [0]],
      ], features["weights"])
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(next_features)


if __name__ == "__main__":
  tf.test.main()
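The padded shapes asserted above come from tf.data-style padded batching: variable-length sequences in a batch are zero-padded to the longest element, or to a fixed max_length when use_tpu=True, where static shapes are required. The following is a minimal sketch of that padding behavior under TF 1.x, independent of the TFRecordDataset builder tested above; the toy `sequences` fixture is ours, not part of the repository.

import numpy as np
import tensorflow as tf

# Toy fixture: sequences of lengths 1..5, each with input dimension 1,
# mimicking the test data above.
sequences = [np.arange(10, 10 + n).reshape(-1, 1).astype(np.float32)
             for n in range(1, 6)]

dataset = tf.data.Dataset.from_generator(
    lambda: iter(sequences), tf.float32, tf.TensorShape([None, 1]))
# Each batch is zero-padded to its longest sequence: fixed batch size,
# dynamic time dimension, fixed input dimension -- static shape [5, None, 1].
dataset = dataset.padded_batch(5, padded_shapes=tf.TensorShape([None, 1]))
batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(batch))  # First row: [[10], [0], [0], [0], [0]], as above.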
research/astronet/astrowavenet/data/synthetic_transit_maker.py deleted 100644 → 0
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates synthetic light curves with periodic transit-like dips.
See class docstring below for more information.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
class SyntheticTransitMaker(object):
  """Generates synthetic light curves with periodic transit-like dips.

  These light curves are generated by thresholding noisy sine waves. Each time
  random_light_curve is called, a thresholded sine wave is generated by
  sampling parameters uniformly from the ranges specified below.

  Attributes:
    period_range: A tuple of positive values specifying the range of periods
      the sine waves may take.
    amplitude_range: A tuple of positive values specifying the range of
      amplitudes the sine waves may take.
    threshold_ratio_range: A tuple of values in [0, 1) specifying the range of
      thresholds as a ratio of the sine wave amplitude.
    phase_range: Tuple of values specifying the range of phases the sine wave
      may take as a ratio of the sampled period. E.g. a sampled phase of 0.5
      would translate the sine wave by half of the period. The most common
      reason to override this would be to generate light curves
      deterministically (with e.g. (0, 0)).
    noise_sd_range: A tuple of values in [0, 1) specifying the range of
      standard deviations for the Gaussian noise applied to the sine wave.
  """

  def __init__(self,
               period_range=(0.5, 4),
               amplitude_range=(1, 1),
               threshold_ratio_range=(0, 0.99),
               phase_range=(0, 1),
               noise_sd_range=(0.1, 0.1)):
    if threshold_ratio_range[0] < 0 or threshold_ratio_range[1] >= 1:
      raise ValueError(
          "Threshold ratio range must be in [0, 1). Got: {}.".format(
              threshold_ratio_range))
    if amplitude_range[0] <= 0:
      raise ValueError(
          "Amplitude range must only contain positive numbers. Got: {}.".format(
              amplitude_range))
    if period_range[0] <= 0:
      raise ValueError(
          "Period range must only contain positive numbers. Got: {}.".format(
              period_range))
    if noise_sd_range[0] < 0:
      raise ValueError(
          "Noise standard deviation range must be nonnegative. Got: {}.".format(
              noise_sd_range))
    for (start, end), name in [(period_range, "period"),
                               (amplitude_range, "amplitude"),
                               (threshold_ratio_range, "threshold ratio"),
                               (phase_range, "phase range"),
                               (noise_sd_range, "noise standard deviation")]:
      if end < start:
        raise ValueError(
            "End of {} range may not be less than start. Got: ({}, {})".format(
                name, start, end))
    self.period_range = period_range
    self.amplitude_range = amplitude_range
    self.threshold_ratio_range = threshold_ratio_range
    self.phase_range = phase_range
    self.noise_sd_range = noise_sd_range
  def random_light_curve(self, time, mask_prob=0):
    """Samples parameters and generates a light curve.

    Args:
      time: np.array, x-values to sample from the thresholded sine wave.
      mask_prob: Value in [0, 1], probability an individual datapoint is set
        to zero.

    Returns:
      flux: np.array, values of the masked sampled light curve corresponding
        to the provided time array.
      mask: np.array of ones and zeros, with zeros indicating masking at the
        respective position on the flux array.
    """
    period = np.random.uniform(*self.period_range)
    phase = np.random.uniform(*self.phase_range) * period
    amplitude = np.random.uniform(*self.amplitude_range)
    threshold = np.random.uniform(*self.threshold_ratio_range) * amplitude
    sin_wave = np.sin(time / period - phase) * amplitude
    flux = np.minimum(sin_wave, -threshold) + threshold
    noise_sd = np.random.uniform(*self.noise_sd_range)
    noise = np.random.normal(scale=noise_sd, size=(len(time),))
    flux += noise
    # Array of ones and zeros, where zeros indicate masking.
    mask = np.random.random(len(time)) > mask_prob
    mask = mask.astype(np.float)
    return flux * mask, mask

  def random_light_curve_generator(self, time, mask_prob=0):
    """Returns a generator function yielding random light curves.

    Args:
      time: An np.array of x-values to sample from the thresholded sine wave.
      mask_prob: Value in [0, 1], probability an individual datapoint is set
        to zero.

    Returns:
      A generator yielding random light curves.
    """

    def generator_fn():
      while True:
        yield self.random_light_curve(time, mask_prob)

    return generator_fn
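For reference, a minimal usage sketch of this class; the variable names below are illustrative, not part of the module. Note how the clipping step works: np.minimum(sin_wave, -threshold) + threshold flattens everything above -threshold to zero, so the curve is a flat baseline interrupted by periodic dips of depth (amplitude - threshold).

import numpy as np

maker = SyntheticTransitMaker(period_range=(1, 2), noise_sd_range=(0.05, 0.05))
time = np.linspace(0, 20, 500)

# One random draw: flux is the noisy, thresholded sine wave with masked points
# zeroed out; mask holds 1 where a point was kept and 0 where it was dropped.
flux, mask = maker.random_light_curve(time, mask_prob=0.1)
assert flux.shape == mask.shape == time.shape

# Endless stream of fresh (flux, mask) pairs, suitable for something like
# tf.data.Dataset.from_generator.
generator_fn = maker.random_light_curve_generator(time, mask_prob=0.1)
flux2, mask2 = next(generator_fn())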