Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
6703 deletions
+0
-6703
research/syntaxnet/dragnn/runtime/BUILD
research/syntaxnet/dragnn/runtime/BUILD
+0
-2296
research/syntaxnet/dragnn/runtime/activation_functions.h
research/syntaxnet/dragnn/runtime/activation_functions.h
+0
-62
research/syntaxnet/dragnn/runtime/activation_functions_test.cc
...rch/syntaxnet/dragnn/runtime/activation_functions_test.cc
+0
-56
research/syntaxnet/dragnn/runtime/alignment.h
research/syntaxnet/dragnn/runtime/alignment.h
+0
-462
research/syntaxnet/dragnn/runtime/alignment_test.cc
research/syntaxnet/dragnn/runtime/alignment_test.cc
+0
-760
research/syntaxnet/dragnn/runtime/array_variable_store.cc
research/syntaxnet/dragnn/runtime/array_variable_store.cc
+0
-181
research/syntaxnet/dragnn/runtime/array_variable_store.h
research/syntaxnet/dragnn/runtime/array_variable_store.h
+0
-86
research/syntaxnet/dragnn/runtime/array_variable_store_builder.cc
.../syntaxnet/dragnn/runtime/array_variable_store_builder.cc
+0
-91
research/syntaxnet/dragnn/runtime/array_variable_store_builder.h
...h/syntaxnet/dragnn/runtime/array_variable_store_builder.h
+0
-52
research/syntaxnet/dragnn/runtime/array_variable_store_builder_test.cc
...axnet/dragnn/runtime/array_variable_store_builder_test.cc
+0
-141
research/syntaxnet/dragnn/runtime/array_variable_store_test.cc
...rch/syntaxnet/dragnn/runtime/array_variable_store_test.cc
+0
-384
research/syntaxnet/dragnn/runtime/attributes.cc
research/syntaxnet/dragnn/runtime/attributes.cc
+0
-117
research/syntaxnet/dragnn/runtime/attributes.h
research/syntaxnet/dragnn/runtime/attributes.h
+0
-204
research/syntaxnet/dragnn/runtime/attributes_test.cc
research/syntaxnet/dragnn/runtime/attributes_test.cc
+0
-260
research/syntaxnet/dragnn/runtime/biaffine_digraph_component.cc
...ch/syntaxnet/dragnn/runtime/biaffine_digraph_component.cc
+0
-259
research/syntaxnet/dragnn/runtime/biaffine_digraph_component_test.cc
...ntaxnet/dragnn/runtime/biaffine_digraph_component_test.cc
+0
-345
research/syntaxnet/dragnn/runtime/bulk_dynamic_component.cc
research/syntaxnet/dragnn/runtime/bulk_dynamic_component.cc
+0
-217
research/syntaxnet/dragnn/runtime/bulk_dynamic_component_test.cc
...h/syntaxnet/dragnn/runtime/bulk_dynamic_component_test.cc
+0
-276
research/syntaxnet/dragnn/runtime/bulk_feed_forward_network.cc
...rch/syntaxnet/dragnn/runtime/bulk_feed_forward_network.cc
+0
-90
research/syntaxnet/dragnn/runtime/bulk_feed_forward_network_test.cc
...yntaxnet/dragnn/runtime/bulk_feed_forward_network_test.cc
+0
-364
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/BUILD
deleted
100644 → 0
View file @
a4bb31d0
package
(
default_visibility
=
[
"//visibility:public"
],
)
load
(
"@org_tensorflow//tensorflow:tensorflow.bzl"
,
"if_linux_x86_64"
,
)
load
(
"//dragnn/runtime:multiarch.bzl"
,
"dragnn_cc_multiarch_binary"
,
"dragnn_cc_multiarch_library"
,
"dragnn_cc_multiarch_test"
,
)
FAST_MATH_COPTS
=
if_linux_x86_64
([
"-O3"
,
"-msse4.2"
,
"-ffast-math"
,
"-ftree-vectorize"
,
])
filegroup
(
name
=
"test_rnn_tagger"
,
srcs
=
glob
([
"testdata/rnn_tagger/**"
]),
)
cc_library
(
name
=
"alignment"
,
hdrs
=
[
"alignment.h"
],
deps
=
[
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"alignment_test"
,
size
=
"small"
,
srcs
=
[
"alignment_test.cc"
],
deps
=
[
":alignment"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"mmap"
,
srcs
=
[
"mmap.cc"
],
hdrs
=
[
"mmap.h"
],
deps
=
[
":alignment"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"mmap_test"
,
size
=
"small"
,
srcs
=
[
"mmap_test.cc"
],
data
=
[
"testdata/empty_file"
,
"testdata/ten_bytes"
,
],
deps
=
[
":mmap"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"operands"
,
srcs
=
[
"operands.cc"
],
hdrs
=
[
"operands.h"
],
deps
=
[
":alignment"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"operands_test"
,
size
=
"small"
,
srcs
=
[
"operands_test.cc"
],
deps
=
[
":alignment"
,
":operands"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"variable_store"
,
hdrs
=
[
"variable_store.h"
],
deps
=
[
":alignment"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"variable_store_test"
,
size
=
"small"
,
srcs
=
[
"variable_store_test.cc"
],
deps
=
[
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/test:fake_variable_store"
,
"//dragnn/runtime/test:helpers"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"trained_model"
,
srcs
=
[
"trained_model.cc"
],
hdrs
=
[
"trained_model.h"
],
deps
=
[
"//dragnn/core:dragnn_bulk_ops_cc"
,
"//dragnn/core:dragnn_ops_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:parser_ops_cc"
,
"@org_tensorflow//tensorflow/cc/saved_model:loader"
,
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants"
,
"@org_tensorflow//tensorflow/core:core_cpu"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
],
)
cc_test
(
name
=
"trained_model_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"trained_model_test.cc"
],
data
=
[
":test_rnn_tagger"
],
deps
=
[
":trained_model"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"trained_model_variable_store"
,
srcs
=
[
"trained_model_variable_store.cc"
],
hdrs
=
[
"trained_model_variable_store.h"
],
deps
=
[
":alignment"
,
":trained_model"
,
":variable_store"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:core_cpu"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:protos_all_cc"
,
"@org_tensorflow//tensorflow/core:tensorflow"
,
],
)
cc_test
(
name
=
"trained_model_variable_store_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"trained_model_variable_store_test.cc"
],
data
=
[
":test_rnn_tagger"
],
shard_count
=
13
,
deps
=
[
":trained_model_variable_store"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:avx_vector_array"
,
"//dragnn/runtime/math:float16_types"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"variable_store_wrappers"
,
srcs
=
[
"variable_store_wrappers.cc"
],
hdrs
=
[
"variable_store_wrappers.h"
],
deps
=
[
":alignment"
,
":flexible_matrix_kernel"
,
":variable_store"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"variable_store_wrappers_test"
,
size
=
"small"
,
srcs
=
[
"variable_store_wrappers_test.cc"
],
deps
=
[
":flexible_matrix_kernel"
,
":variable_store_wrappers"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:transformations"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:fake_variable_store"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"array_variable_store"
,
srcs
=
[
"array_variable_store.cc"
],
hdrs
=
[
"array_variable_store.h"
],
deps
=
[
":alignment"
,
":variable_store"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"array_variable_store_test"
,
size
=
"small"
,
srcs
=
[
"array_variable_store_test.cc"
],
data
=
[
"testdata/array_variable_store_data"
,
"testdata/array_variable_store_spec"
,
"testdata/empty_file"
,
],
deps
=
[
":alignment"
,
":array_variable_store"
,
":file_array_variable_store"
,
":mmap_array_variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"array_variable_store_builder"
,
srcs
=
[
"array_variable_store_builder.cc"
],
hdrs
=
[
"array_variable_store_builder.h"
],
deps
=
[
":alignment"
,
":array_variable_store"
,
":variable_store_wrappers"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"array_variable_store_builder_test"
,
size
=
"small"
,
srcs
=
[
"array_variable_store_builder_test.cc"
],
data
=
[
"testdata/array_variable_store_data"
,
"testdata/array_variable_store_spec"
,
],
deps
=
[
":alignment"
,
":array_variable_store_builder"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/test:helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
# Tested in array_variable_store_test.
cc_library
(
name
=
"file_array_variable_store"
,
srcs
=
[
"file_array_variable_store.cc"
],
hdrs
=
[
"file_array_variable_store.h"
],
deps
=
[
":alignment"
,
":array_variable_store"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
# Tested in array_variable_store_test.
cc_library
(
name
=
"mmap_array_variable_store"
,
srcs
=
[
"mmap_array_variable_store.cc"
],
hdrs
=
[
"mmap_array_variable_store.h"
],
deps
=
[
":array_variable_store"
,
":mmap"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_library
(
name
=
"network_states"
,
srcs
=
[
"network_states.cc"
],
hdrs
=
[
"network_states.h"
],
deps
=
[
":alignment"
,
":operands"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"network_states_test"
,
size
=
"small"
,
srcs
=
[
"network_states_test.cc"
],
deps
=
[
":alignment"
,
":network_states"
,
"//dragnn/core/test:generic"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"extensions"
,
srcs
=
[
"extensions.cc"
],
hdrs
=
[
"extensions.h"
],
deps
=
[
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"extensions_test"
,
size
=
"small"
,
srcs
=
[
"extensions_test.cc"
],
deps
=
[
":extensions"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"linked_embeddings"
,
srcs
=
[
"linked_embeddings.cc"
],
hdrs
=
[
"linked_embeddings.h"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":alignment"
,
":flexible_matrix_kernel"
,
":network_states"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:data_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:arithmetic"
,
"//dragnn/runtime/math:types"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"linked_embeddings_test"
,
size
=
"small"
,
srcs
=
[
"linked_embeddings_test.cc"
],
deps
=
[
":linked_embeddings"
,
":network_states"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"fixed_embeddings"
,
srcs
=
[
"fixed_embeddings.cc"
],
hdrs
=
[
"fixed_embeddings.h"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":alignment"
,
":network_states"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:arithmetic"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"fixed_embeddings_test"
,
size
=
"small"
,
srcs
=
[
"fixed_embeddings_test.cc"
],
deps
=
[
":fixed_embeddings"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"type_keyed_set"
,
hdrs
=
[
"type_keyed_set.h"
],
)
cc_test
(
name
=
"type_keyed_set_test"
,
size
=
"small"
,
srcs
=
[
"type_keyed_set_test.cc"
],
deps
=
[
":type_keyed_set"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"session_state"
,
hdrs
=
[
"session_state.h"
],
deps
=
[
":extensions"
,
":network_states"
,
],
)
cc_library
(
name
=
"session_state_pool"
,
srcs
=
[
"session_state_pool.cc"
],
hdrs
=
[
"session_state_pool.h"
],
deps
=
[
":session_state"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"session_state_pool_test"
,
size
=
"small"
,
srcs
=
[
"session_state_pool_test.cc"
],
deps
=
[
":session_state"
,
":session_state_pool"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"bulk_dynamic_component"
,
srcs
=
[
"bulk_dynamic_component.cc"
],
deps
=
[
":bulk_network_unit"
,
":component"
,
":extensions"
,
":network_states"
,
":network_unit_base"
,
":transition_system_traits"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"bulk_dynamic_component_test"
,
srcs
=
[
"bulk_dynamic_component_test.cc"
],
deps
=
[
":bulk_dynamic_component"
,
":bulk_network_unit"
,
":component"
,
":extensions"
,
":network_states"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"sequence_bulk_dynamic_component"
,
srcs
=
[
"sequence_bulk_dynamic_component.cc"
],
deps
=
[
":bulk_network_unit"
,
":component"
,
":extensions"
,
":fixed_embeddings"
,
":linked_embeddings"
,
":network_states"
,
":sequence_model"
,
":session_state"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"sequence_bulk_dynamic_component_test"
,
srcs
=
[
"sequence_bulk_dynamic_component_test.cc"
],
deps
=
[
":bulk_network_unit"
,
":component"
,
":extensions"
,
":network_states"
,
":sequence_backend"
,
":sequence_bulk_dynamic_component"
,
":sequence_extractor"
,
":sequence_linker"
,
":sequence_predictor"
,
":variable_store"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"component"
,
srcs
=
[
"component.cc"
],
hdrs
=
[
"component.h"
],
deps
=
[
":extensions"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"component_test"
,
size
=
"small"
,
srcs
=
[
"component_test.cc"
],
deps
=
[
":component"
,
":extensions"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"lstm_network_kernel"
,
srcs
=
[
"lstm_network_kernel.cc"
],
hdrs
=
[
"lstm_network_kernel.h"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":attributes"
,
":extensions"
,
":feed_forward_network_layer"
,
":flexible_matrix_kernel"
,
":network_states"
,
":session_state"
,
":transition_system_traits"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/lstm_cell:cell_function"
,
"//dragnn/runtime/math:avx_activation_functions"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"lstm_network_kernel_test"
,
srcs
=
[
"lstm_network_kernel_test.cc"
],
deps
=
[
":lstm_network_kernel"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/lstm_cell:cell_function"
,
"//dragnn/runtime/test:helpers"
,
"//dragnn/runtime/test:network_test_base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"lstm_network"
,
srcs
=
[
"lstm_network.cc"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":extensions"
,
":lstm_network_kernel"
,
":network_unit"
,
":network_unit_base"
,
":transition_system_traits"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/lstm_cell:cell_function"
,
"//dragnn/runtime/math:avx_activation_functions"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"lstm_network_test"
,
srcs
=
[
"lstm_network_test.cc"
],
deps
=
[
":flexible_matrix_kernel"
,
":lstm_network"
,
":network_unit"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/lstm_cell:cell_function"
,
"//dragnn/runtime/test:network_test_base"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"bulk_lstm_network"
,
srcs
=
[
"bulk_lstm_network.cc"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":bulk_network_unit"
,
":extensions"
,
":lstm_network_kernel"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"bulk_lstm_network_test"
,
srcs
=
[
"bulk_lstm_network_test.cc"
],
deps
=
[
":bulk_lstm_network"
,
":bulk_network_unit"
,
":flexible_matrix_kernel"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/lstm_cell:cell_function"
,
"//dragnn/runtime/test:helpers"
,
"//dragnn/runtime/test:network_test_base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"master"
,
srcs
=
[
"master.cc"
],
hdrs
=
[
"master.h"
],
deps
=
[
":component"
,
":extensions"
,
":network_states"
,
":session_state"
,
":session_state_pool"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"master_test"
,
size
=
"small"
,
srcs
=
[
"master_test.cc"
],
deps
=
[
":alignment"
,
":component"
,
":extensions"
,
":master"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:mock_compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/test:fake_variable_store"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"network_unit"
,
srcs
=
[
"network_unit.cc"
],
hdrs
=
[
"network_unit.h"
],
deps
=
[
":extensions"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"network_unit_test"
,
size
=
"small"
,
srcs
=
[
"network_unit_test.cc"
],
deps
=
[
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"bulk_network_unit"
,
srcs
=
[
"bulk_network_unit.cc"
],
hdrs
=
[
"bulk_network_unit.h"
],
deps
=
[
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"bulk_network_unit_test"
,
size
=
"small"
,
srcs
=
[
"bulk_network_unit_test.cc"
],
deps
=
[
":bulk_network_unit"
,
":extensions"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"dynamic_component"
,
srcs
=
[
"dynamic_component.cc"
],
deps
=
[
":component"
,
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":transition_system_traits"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"dynamic_component_test"
,
size
=
"small"
,
srcs
=
[
"dynamic_component_test.cc"
],
deps
=
[
":component"
,
":dynamic_component"
,
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"network_unit_base"
,
srcs
=
[
"network_unit_base.cc"
],
hdrs
=
[
"network_unit_base.h"
],
deps
=
[
":extensions"
,
":fixed_embeddings"
,
":linked_embeddings"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"network_unit_base_test"
,
size
=
"small"
,
srcs
=
[
"network_unit_base_test.cc"
],
deps
=
[
":extensions"
,
":fixed_embeddings"
,
":linked_embeddings"
,
":network_states"
,
":network_unit_base"
,
":session_state"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"attributes"
,
srcs
=
[
"attributes.cc"
],
hdrs
=
[
"attributes.h"
],
deps
=
[
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"attributes_test"
,
size
=
"small"
,
srcs
=
[
"attributes_test.cc"
],
deps
=
[
":attributes"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"activation_functions"
,
hdrs
=
[
"activation_functions.h"
],
deps
=
[
"//dragnn/runtime/math:arithmetic"
,
"//dragnn/runtime/math:types"
,
],
)
cc_test
(
name
=
"activation_functions_test"
,
size
=
"small"
,
srcs
=
[
"activation_functions_test.cc"
],
deps
=
[
":activation_functions"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:helpers"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"flexible_matrix_kernel"
,
srcs
=
[
"flexible_matrix_kernel.cc"
],
hdrs
=
[
"flexible_matrix_kernel.h"
],
deps
=
[
":alignment"
,
":variable_store"
,
"//dragnn/runtime/math:arithmetic"
,
"//dragnn/runtime/math:avx_vector_array"
,
"//dragnn/runtime/math:sgemvv"
,
"//dragnn/runtime/math:types"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"flexible_matrix_kernel_test"
,
srcs
=
[
"flexible_matrix_kernel_test.cc"
],
copts
=
FAST_MATH_COPTS
,
deps
=
[
":flexible_matrix_kernel"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/runtime/math:transformations"
,
"//dragnn/runtime/test:fake_variable_store"
,
"//dragnn/runtime/test:helpers"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"feed_forward_network_layer"
,
srcs
=
[
"feed_forward_network_layer.cc"
],
hdrs
=
[
"feed_forward_network_layer.h"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":activation_functions"
,
":flexible_matrix_kernel"
,
":network_states"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"feed_forward_network_layer_test"
,
size
=
"small"
,
srcs
=
[
"feed_forward_network_layer_test.cc"
],
deps
=
[
":activation_functions"
,
":feed_forward_network_layer"
,
":flexible_matrix_kernel"
,
":network_states"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:helpers"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"feed_forward_network_kernel"
,
srcs
=
[
"feed_forward_network_kernel.cc"
],
hdrs
=
[
"feed_forward_network_kernel.h"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":activation_functions"
,
":attributes"
,
":feed_forward_network_layer"
,
":flexible_matrix_kernel"
,
":network_states"
,
":transition_system_traits"
,
":variable_store"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"feed_forward_network_kernel_test"
,
size
=
"small"
,
srcs
=
[
"feed_forward_network_kernel_test.cc"
],
deps
=
[
":activation_functions"
,
":feed_forward_network_kernel"
,
":flexible_matrix_kernel"
,
":network_states"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"feed_forward_network"
,
srcs
=
[
"feed_forward_network.cc"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":extensions"
,
":feed_forward_network_kernel"
,
":feed_forward_network_layer"
,
":network_states"
,
":network_unit"
,
":network_unit_base"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"feed_forward_network_test"
,
size
=
"small"
,
srcs
=
[
"feed_forward_network_test.cc"
],
deps
=
[
":dynamic_component"
,
":feed_forward_network"
,
":flexible_matrix_kernel"
,
":network_states"
,
":network_unit"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"bulk_feed_forward_network"
,
srcs
=
[
"bulk_feed_forward_network.cc"
],
copts
=
FAST_MATH_COPTS
,
opts_self
=
True
,
deps
=
[
":bulk_network_unit"
,
":extensions"
,
":feed_forward_network_kernel"
,
":feed_forward_network_layer"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"bulk_feed_forward_network_test"
,
size
=
"small"
,
srcs
=
[
"bulk_feed_forward_network_test.cc"
],
deps
=
[
":bulk_feed_forward_network"
,
":bulk_network_unit"
,
":dynamic_component"
,
":flexible_matrix_kernel"
,
":network_states"
,
":variable_store"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"conversion"
,
srcs
=
[
"conversion.cc"
],
hdrs
=
[
"conversion.h"
],
deps
=
[
":array_variable_store_builder"
,
":master"
,
":trained_model_variable_store"
,
":variable_store"
,
":variable_store_wrappers"
,
"//dragnn/protos:runtime_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"conversion_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"conversion_test.cc"
],
data
=
[
"testdata/conversion_output_variables_data"
,
"testdata/conversion_output_variables_spec"
,
":test_rnn_tagger"
,
],
shard_count
=
6
,
deps
=
[
":conversion"
,
":dynamic_component"
,
":feed_forward_network"
,
":lstm_network"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:runtime_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"component_transformation"
,
srcs
=
[
"component_transformation.cc"
],
hdrs
=
[
"component_transformation.h"
],
deps
=
[
":component"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"component_transformation_test"
,
size
=
"small"
,
srcs
=
[
"component_transformation_test.cc"
],
deps
=
[
":component_transformation"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"fml_parsing"
,
srcs
=
[
"fml_parsing.cc"
],
hdrs
=
[
"fml_parsing.h"
],
deps
=
[
":attributes"
,
"//syntaxnet:base"
,
"//syntaxnet:feature_extractor_proto_cc"
,
"//syntaxnet:fml_parser"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"fml_parsing_test"
,
size
=
"small"
,
srcs
=
[
"fml_parsing_test.cc"
],
deps
=
[
":fml_parsing"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"//syntaxnet:feature_extractor_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"term_map_utils"
,
srcs
=
[
"term_map_utils.cc"
],
hdrs
=
[
"term_map_utils.h"
],
deps
=
[
":fml_parsing"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:feature_extractor_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"term_map_utils_test"
,
size
=
"small"
,
srcs
=
[
"term_map_utils_test.cc"
],
deps
=
[
":term_map_utils"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"transition_system_traits"
,
srcs
=
[
"transition_system_traits.cc"
],
hdrs
=
[
"transition_system_traits.h"
],
deps
=
[
":attributes"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"transition_system_traits_test"
,
size
=
"small"
,
srcs
=
[
"transition_system_traits_test.cc"
],
deps
=
[
":transition_system_traits"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"unicode_dictionary"
,
srcs
=
[
"unicode_dictionary.cc"
],
hdrs
=
[
"unicode_dictionary.h"
],
deps
=
[
"//syntaxnet:base"
,
"//syntaxnet:term_frequency_map"
,
"//util/utf8:unicodetext"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"unicode_dictionary_test"
,
size
=
"small"
,
timeout
=
"moderate"
,
srcs
=
[
"unicode_dictionary_test.cc"
],
deps
=
[
":unicode_dictionary"
,
"//dragnn/core/test:generic"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:term_frequency_map"
,
"//third_party/utf"
,
"//util/utf8:unicodetext"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"sequence_extractor"
,
srcs
=
[
"sequence_extractor.cc"
],
hdrs
=
[
"sequence_extractor.h"
],
deps
=
[
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"sequence_extractor_test"
,
size
=
"small"
,
srcs
=
[
"sequence_extractor_test.cc"
],
deps
=
[
":sequence_extractor"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"term_map_sequence_extractor"
,
hdrs
=
[
"term_map_sequence_extractor.h"
],
deps
=
[
":sequence_extractor"
,
":term_map_utils"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:shared_store"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"term_map_sequence_extractor_test"
,
size
=
"small"
,
srcs
=
[
"term_map_sequence_extractor_test.cc"
],
deps
=
[
":term_map_sequence_extractor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:term_frequency_map"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_character_sequence_extractor"
,
srcs
=
[
"syntaxnet_character_sequence_extractor.cc"
],
deps
=
[
":sequence_extractor"
,
":term_map_sequence_extractor"
,
":term_map_utils"
,
":transition_system_traits"
,
":unicode_dictionary"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:segmenter_utils"
,
"//util/utf8:unicodetext"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_character_sequence_extractor_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_character_sequence_extractor_test.cc"
],
deps
=
[
":sequence_extractor"
,
":syntaxnet_character_sequence_extractor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_word_sequence_extractor"
,
srcs
=
[
"syntaxnet_word_sequence_extractor.cc"
],
deps
=
[
":sequence_extractor"
,
":term_map_sequence_extractor"
,
":term_map_utils"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:term_frequency_map"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_word_sequence_extractor_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_word_sequence_extractor_test.cc"
],
deps
=
[
":sequence_extractor"
,
":syntaxnet_word_sequence_extractor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"sequence_features"
,
srcs
=
[
"sequence_features.cc"
],
hdrs
=
[
"sequence_features.h"
],
deps
=
[
":alignment"
,
":fixed_embeddings"
,
":sequence_extractor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"sequence_features_test"
,
size
=
"small"
,
srcs
=
[
"sequence_features_test.cc"
],
deps
=
[
":fixed_embeddings"
,
":sequence_extractor"
,
":sequence_features"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"sequence_linker"
,
srcs
=
[
"sequence_linker.cc"
],
hdrs
=
[
"sequence_linker.h"
],
deps
=
[
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"sequence_linker_test"
,
size
=
"small"
,
srcs
=
[
"sequence_linker_test.cc"
],
deps
=
[
":sequence_linker"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"identity_sequence_linker"
,
srcs
=
[
"identity_sequence_linker.cc"
],
deps
=
[
":sequence_linker"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"identity_sequence_linker_test"
,
size
=
"small"
,
srcs
=
[
"identity_sequence_linker_test.cc"
],
deps
=
[
":identity_sequence_linker"
,
":sequence_linker"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"reversed_sequence_linker"
,
srcs
=
[
"reversed_sequence_linker.cc"
],
deps
=
[
":sequence_linker"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"reversed_sequence_linker_test"
,
size
=
"small"
,
srcs
=
[
"reversed_sequence_linker_test.cc"
],
deps
=
[
":reversed_sequence_linker"
,
":sequence_linker"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"recurrent_sequence_linkers"
,
srcs
=
[
"recurrent_sequence_linkers.cc"
],
deps
=
[
":sequence_linker"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"recurrent_sequence_linkers_test"
,
size
=
"small"
,
srcs
=
[
"recurrent_sequence_linkers_test.cc"
],
deps
=
[
":recurrent_sequence_linkers"
,
":sequence_linker"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_character_sequence_linkers"
,
srcs
=
[
"syntaxnet_character_sequence_linkers.cc"
],
deps
=
[
":sequence_linker"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//util/utf8:unicodetext"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_character_sequence_linkers_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_character_sequence_linkers_test.cc"
],
deps
=
[
":sequence_linker"
,
":syntaxnet_character_sequence_linkers"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"sequence_links"
,
srcs
=
[
"sequence_links.cc"
],
hdrs
=
[
"sequence_links.h"
],
deps
=
[
":alignment"
,
":linked_embeddings"
,
":network_states"
,
":sequence_linker"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"sequence_links_test"
,
size
=
"small"
,
srcs
=
[
"sequence_links_test.cc"
],
deps
=
[
":linked_embeddings"
,
":network_states"
,
":sequence_linker"
,
":sequence_links"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"sequence_predictor"
,
srcs
=
[
"sequence_predictor.cc"
],
hdrs
=
[
"sequence_predictor.h"
],
deps
=
[
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"//syntaxnet:registry"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"sequence_predictor_test"
,
size
=
"small"
,
srcs
=
[
"sequence_predictor_test.cc"
],
deps
=
[
":sequence_predictor"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"term_map_sequence_predictor"
,
srcs
=
[
"term_map_sequence_predictor.cc"
],
hdrs
=
[
"term_map_sequence_predictor.h"
],
deps
=
[
":sequence_predictor"
,
":term_map_utils"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:shared_store"
,
"//syntaxnet:term_frequency_map"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"term_map_sequence_predictor_test"
,
size
=
"small"
,
srcs
=
[
"term_map_sequence_predictor_test.cc"
],
deps
=
[
":term_map_sequence_predictor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_tag_sequence_predictor"
,
srcs
=
[
"syntaxnet_tag_sequence_predictor.cc"
],
deps
=
[
":sequence_predictor"
,
":term_map_sequence_predictor"
,
":transition_system_traits"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"//syntaxnet:term_frequency_map"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_tag_sequence_predictor_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_tag_sequence_predictor_test.cc"
],
deps
=
[
":alignment"
,
":sequence_predictor"
,
":syntaxnet_tag_sequence_predictor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/test:helpers"
,
"//dragnn/runtime/test:term_map_helpers"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"sequence_backend"
,
srcs
=
[
"sequence_backend.cc"
],
hdrs
=
[
"sequence_backend.h"
],
deps
=
[
"//dragnn/core:component_registry"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/interfaces:component"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/util:label"
,
"//dragnn/protos:data_proto_cc"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"sequence_backend_test"
,
size
=
"small"
,
srcs
=
[
"sequence_backend_test.cc"
],
deps
=
[
":sequence_backend"
,
"//dragnn/components/util:bulk_feature_extractor"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/interfaces:transition_state"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"select_best_component_transformer"
,
srcs
=
[
"select_best_component_transformer.cc"
],
deps
=
[
":component"
,
":component_transformation"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"select_best_component_transformer_test"
,
size
=
"small"
,
srcs
=
[
"select_best_component_transformer_test.cc"
],
deps
=
[
":component"
,
":component_transformation"
,
":extensions"
,
":select_best_component_transformer"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"sequence_component_transformer"
,
srcs
=
[
"sequence_component_transformer.cc"
],
deps
=
[
":component_transformation"
,
":sequence_extractor"
,
":sequence_linker"
,
":sequence_predictor"
,
":transition_system_traits"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"sequence_component_transformer_test"
,
size
=
"small"
,
srcs
=
[
"sequence_component_transformer_test.cc"
],
deps
=
[
":component_transformation"
,
":sequence_component_transformer"
,
":sequence_extractor"
,
":sequence_linker"
,
":sequence_predictor"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"stateless_component_transformer"
,
srcs
=
[
"stateless_component_transformer.cc"
],
deps
=
[
":component_transformation"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"stateless_component_transformer_test"
,
size
=
"small"
,
srcs
=
[
"stateless_component_transformer_test.cc"
],
deps
=
[
":component_transformation"
,
":stateless_component_transformer"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"clear_dropout_component_transformer"
,
srcs
=
[
"clear_dropout_component_transformer.cc"
],
deps
=
[
":component_transformation"
,
"//dragnn/protos:spec_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:feature_extractor_proto_cc"
,
"//syntaxnet:fml_parser"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"clear_dropout_component_transformer_test"
,
size
=
"small"
,
srcs
=
[
"clear_dropout_component_transformer_test.cc"
],
deps
=
[
":clear_dropout_component_transformer"
,
":component_transformation"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"sequence_model"
,
srcs
=
[
"sequence_model.cc"
],
hdrs
=
[
"sequence_model.h"
],
deps
=
[
":attributes"
,
":fixed_embeddings"
,
":linked_embeddings"
,
":network_states"
,
":sequence_backend"
,
":sequence_features"
,
":sequence_links"
,
":sequence_predictor"
,
":session_state"
,
":transition_system_traits"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
dragnn_cc_multiarch_test
(
name
=
"sequence_model_test"
,
size
=
"small"
,
srcs
=
[
"sequence_model_test.cc"
],
deps
=
[
":fixed_embeddings"
,
":linked_embeddings"
,
":network_states"
,
":sequence_backend"
,
":sequence_extractor"
,
":sequence_linker"
,
":sequence_model"
,
":sequence_predictor"
,
":session_state"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
dragnn_cc_multiarch_library
(
name
=
"biaffine_digraph_component"
,
srcs
=
[
"biaffine_digraph_component.cc"
],
copts
=
FAST_MATH_COPTS
,
deps
=
[
":component"
,
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/math:eigen"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
dragnn_cc_multiarch_test
(
name
=
"biaffine_digraph_component_test"
,
size
=
"small"
,
srcs
=
[
"biaffine_digraph_component_test.cc"
],
deps
=
[
":biaffine_digraph_component"
,
":component"
,
":extensions"
,
":network_states"
,
":session_state"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"head_selection_component_base"
,
srcs
=
[
"head_selection_component_base.cc"
],
hdrs
=
[
"head_selection_component_base.h"
],
deps
=
[
":alignment"
,
":component"
,
":extensions"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"head_selection_component_base_test"
,
size
=
"small"
,
srcs
=
[
"head_selection_component_base_test.cc"
],
deps
=
[
":head_selection_component_base"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_head_selection_component"
,
srcs
=
[
"syntaxnet_head_selection_component.cc"
],
deps
=
[
":head_selection_component_base"
,
":session_state"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_head_selection_component_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_head_selection_component_test.cc"
],
deps
=
[
":component"
,
":syntaxnet_head_selection_component"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"mst_solver_component_base"
,
srcs
=
[
"mst_solver_component_base.cc"
],
hdrs
=
[
"mst_solver_component_base.h"
],
deps
=
[
":attributes"
,
":component"
,
":extensions"
,
":network_states"
,
":network_unit"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/mst:mst_solver"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"mst_solver_component_base_test"
,
size
=
"small"
,
srcs
=
[
"mst_solver_component_base_test.cc"
],
deps
=
[
":mst_solver_component_base"
,
":network_states"
,
":session_state"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/protos:trace_proto_cc"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"syntaxnet_mst_solver_component"
,
srcs
=
[
"syntaxnet_mst_solver_component.cc"
],
deps
=
[
":mst_solver_component_base"
,
":session_state"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:trace_proto_cc"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
cc_test
(
name
=
"syntaxnet_mst_solver_component_test"
,
size
=
"small"
,
srcs
=
[
"syntaxnet_mst_solver_component_test.cc"
],
deps
=
[
":component"
,
":syntaxnet_mst_solver_component"
,
":variable_store"
,
"//dragnn/core:compute_session"
,
"//dragnn/core:input_batch_cache"
,
"//dragnn/core/test:generic"
,
"//dragnn/io:sentence_input_batch"
,
"//dragnn/io:syntaxnet_sentence"
,
"//dragnn/protos:spec_proto_cc"
,
"//dragnn/runtime/math:types"
,
"//dragnn/runtime/test:network_test_base"
,
"//syntaxnet:sentence_proto_cc"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"converter_main"
,
srcs
=
[
"converter.cc"
],
deps
=
[
":component_transformation"
,
":conversion"
,
"//dragnn/runtime/myelin:myelination"
,
"//dragnn/runtime/xla:xla_compilation"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@sling//sling/base"
,
],
)
dragnn_cc_multiarch_binary
(
name
=
"converter"
,
target_arch
=
"generic"
,
deps
=
[
":biaffine_digraph_component"
,
":bulk_dynamic_component"
,
":bulk_feed_forward_network"
,
":bulk_lstm_network"
,
":clear_dropout_component_transformer"
,
":converter_main"
,
":dynamic_component"
,
":feed_forward_network"
,
":identity_sequence_linker"
,
":lstm_network"
,
":recurrent_sequence_linkers"
,
":reversed_sequence_linker"
,
":select_best_component_transformer"
,
":sequence_backend"
,
":sequence_bulk_dynamic_component"
,
":sequence_component_transformer"
,
":stateless_component_transformer"
,
":syntaxnet_character_sequence_extractor"
,
":syntaxnet_character_sequence_linkers"
,
":syntaxnet_head_selection_component"
,
":syntaxnet_mst_solver_component"
,
":syntaxnet_tag_sequence_predictor"
,
":syntaxnet_word_sequence_extractor"
,
"//dragnn/components/stateless:stateless_component"
,
"//dragnn/components/syntaxnet:syntaxnet_component"
,
"//dragnn/mst:mst_ops_cc"
,
"//dragnn/runtime/myelin:myelin_dynamic_component"
,
"//dragnn/runtime/myelin:sequence_myelin_dynamic_component"
,
"//dragnn/runtime/xla:xla_dynamic_component"
,
"//syntaxnet:parser_transitions"
,
],
)
sh_test
(
name
=
"converter_test"
,
size
=
"medium"
,
srcs
=
[
"converter_test.sh"
],
data
=
[
":converter"
]
+
glob
([
"testdata/converter_output/**"
,
"testdata/rnn_tagger/**"
,
]),
)
research/syntaxnet/dragnn/runtime/activation_functions.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Definitions of activation functions for neural netowrks.
#ifndef DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#define DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#include "dragnn/runtime/math/arithmetic.h"
#include "dragnn/runtime/math/types.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Possible types of activation functions.
//
// TODO(googleuser): If many activation functions are added, or if functions start
// using configuration parameters (e.g., leakiness of a leaky ReLU), then switch
// to a registered class.
enum
class
ActivationFunction
{
kIdentity
,
// pass-through, useful for classification logits
kRelu
,
// ReLU; i.e., max(0,x)
};
// Applies the |activation_function| to the |values|.
template
<
class
T
>
void
ApplyActivationFunction
(
ActivationFunction
activation_function
,
MutableVector
<
T
>
values
);
// Implementation details below.
template
<
class
T
>
void
ApplyActivationFunction
(
ActivationFunction
activation_function
,
MutableVector
<
T
>
values
)
{
switch
(
activation_function
)
{
case
ActivationFunction
::
kIdentity
:
break
;
case
ActivationFunction
::
kRelu
:
MaxElements
(
T
(),
values
);
break
;
}
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
research/syntaxnet/dragnn/runtime/activation_functions_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/activation_functions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/helpers.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that kIdentity is a pass-through.
TEST
(
ActivationFunctionsTest
,
ApplyIdentity
)
{
UniqueVector
<
float
>
values
({
1.25
f
,
-
1.5
f
,
0.0
f
,
0.0625
f
,
-
0.03125
});
ApplyActivationFunction
(
ActivationFunction
::
kIdentity
,
*
values
);
EXPECT_EQ
((
*
values
)[
0
],
1.25
);
EXPECT_EQ
((
*
values
)[
1
],
-
1.5
);
EXPECT_EQ
((
*
values
)[
2
],
0.0
);
EXPECT_EQ
((
*
values
)[
3
],
0.0625
);
EXPECT_EQ
((
*
values
)[
4
],
-
0.03125
);
}
// Tests that kRelu clips to zero.
TEST
(
ActivationFunctionsTest
,
ApplyRelu
)
{
UniqueVector
<
float
>
values
({
1.25
f
,
-
1.5
f
,
0.0
f
,
0.0625
f
,
-
0.03125
});
ApplyActivationFunction
(
ActivationFunction
::
kRelu
,
*
values
);
EXPECT_EQ
((
*
values
)[
0
],
1.25
);
EXPECT_EQ
((
*
values
)[
1
],
0.0
);
// clipped
EXPECT_EQ
((
*
values
)[
2
],
0.0
);
// boundary
EXPECT_EQ
((
*
values
)[
3
],
0.0625
);
EXPECT_EQ
((
*
values
)[
4
],
0.0
);
// clipped
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/alignment.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with aligned memory blocks. The DRAGNN runtime requires
// aligned memory for use in vectorized math. Do not rely on any particular
// value of the alignment requirement, because it will vary over time and in
// different build configurations.
#ifndef DRAGNN_RUNTIME_ALIGNMENT_H_
#define DRAGNN_RUNTIME_ALIGNMENT_H_
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <type_traits>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
// This is a type that has some private methods (so non-POD), but is known to be
// trivially-deconstructable. Ergo we add some special handling so
// IsAlignable<bfloat16> returns true.
namespace
tensorflow
{
struct
bfloat16
;
}
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Returns true if |T| can be used in an aligned memory block.
template
<
class
T
>
constexpr
bool
IsAlignable
();
// Returns OK iff the |pointer| satisfies the alignment requirement.
tensorflow
::
Status
OkIfAligned
(
const
void
*
pointer
);
// Returns the next alignment boundary at or after the |byte_offset|.
size_t
PadToAlignment
(
size_t
byte_offset
);
// As above, but for pointers.
template
<
class
T
>
T
*
PadToAlignment
(
T
*
pointer
);
// Returns the number of bytes required to store a sequence of |num_arrays|
// aligned arrays of |array_size| bytes, including alignment padding. See
// (Mutable)AlignedArea below.
size_t
ComputeAlignedAreaSize
(
size_t
num_arrays
,
size_t
array_size
);
// Returns the number of bytes required to store a sequence of byte arrays of
// the given |sizes|, including alignment padding after each array.
size_t
ComputeTotalBytesWithAlignmentPadding
(
const
std
::
vector
<
size_t
>
&
sizes
);
// Forward-declared for friendship below.
class
Operands
;
class
UniqueAlignedArray
;
enum
class
BlockedMatrixFormat
;
namespace
internal
{
// A non-owning view of an aligned byte array. Templated so const and mutable
// versions can share implementation. Do not use this class directly, instead
// use (Mutable)AlignedView below.
template
<
class
Byte
>
class
AlignedViewImpl
{
public:
static_assert
(
sizeof
(
Byte
)
==
1
,
"Byte must be byte-sized"
);
// Creates an empty view.
AlignedViewImpl
()
=
default
;
// Points this at the same bytes as |that|, possibly reinterpreting type.
template
<
class
OtherByte
>
explicit
AlignedViewImpl
(
AlignedViewImpl
<
OtherByte
>
that
);
template
<
class
OtherByte
>
AlignedViewImpl
&
operator
=
(
AlignedViewImpl
<
OtherByte
>
that
);
// Points this at [|data|,|data|+|size|). On error, returns non-OK and
// modifies nothing.
tensorflow
::
Status
Reset
(
Byte
*
data
,
size_t
size
);
// Splits this into a list of |views| of the |sizes|, possibly reinterpreting
// type. The |views| need not completely cover all bytes of this. Requires
// that this spans ComputeTotalBytesWithAlignmentPadding(|sizes|) bytes. On
// error, returns non-OK and modifies nothing.
template
<
class
OtherByte
>
tensorflow
::
Status
Split
(
const
std
::
vector
<
size_t
>
&
sizes
,
std
::
vector
<
AlignedViewImpl
<
OtherByte
>>
*
views
)
const
;
// Accessors.
Byte
*
data
()
const
{
return
data_
;
}
size_t
size
()
const
{
return
size_
;
}
bool
empty
()
const
{
return
size
()
==
0
;
}
private:
template
<
class
OtherByte
>
friend
class
AlignedViewImpl
;
template
<
class
OtherByte
>
friend
class
AlignedAreaImpl
;
friend
Operands
;
friend
UniqueAlignedArray
;
// Directly creates an aligned view, bypassing alignment checks.
AlignedViewImpl
(
Byte
*
data
,
size_t
size
);
// Pointer to the start of the view.
Byte
*
data_
=
nullptr
;
// Number of bytes in the view.
size_t
size_
=
0
;
};
// A non-owning view of an aligned, 2-dimensional byte array. Templated so
// const and mutable versons can share implementation. Do not use this class
// directly, instead use (Mutable)AlignedArea below.
template
<
class
Byte
>
class
AlignedAreaImpl
{
public:
static_assert
(
sizeof
(
Byte
)
==
1
,
"Byte must be byte-sized"
);
// Creates an empty area.
AlignedAreaImpl
()
=
default
;
// Points this at the same bytes as |that|, possibly reinterpreting type.
template
<
class
OtherByte
>
explicit
AlignedAreaImpl
(
AlignedAreaImpl
<
OtherByte
>
that
);
template
<
class
OtherByte
>
AlignedAreaImpl
&
operator
=
(
AlignedAreaImpl
<
OtherByte
>
that
);
// Resets this to a sequence of |num_views| aligned sub-views of the |view|,
// each |view_size| bytes wide. The first sub-view covers [0,|view_size|) of
// |view|, and each subsequent sub-view starts at the next alignment boundary.
// Requires that |view| spans ComputeAlignedAreaSize(|num_views|,|view_size|)
// bytes or more. On error, returns non-OK and modifies nothing.
template
<
class
OtherByte
>
tensorflow
::
Status
Reset
(
AlignedViewImpl
<
OtherByte
>
view
,
size_t
num_views
,
size_t
view_size
);
// Accessors.
AlignedViewImpl
<
Byte
>
view
(
size_t
index
)
const
;
Byte
*
data
()
const
{
return
data_
;
}
size_t
num_views
()
const
{
return
num_views_
;
}
size_t
view_size
()
const
{
return
view_size_
;
}
size_t
view_stride
()
const
{
return
view_stride_
;
}
bool
empty
()
const
{
return
num_views
()
==
0
;
}
private:
template
<
class
OtherByte
>
friend
class
AlignedAreaImpl
;
friend
Operands
;
// Directly creates an aligned view, bypassing alignment checks.
AlignedAreaImpl
(
Byte
*
data
,
size_t
num_views
,
size_t
view_size
,
size_t
view_stride
);
// Pointer to the start of the first view.
Byte
*
data_
=
nullptr
;
// Number of views in the area.
size_t
num_views_
=
0
;
// Size of each view in bytes, excluding alignment padding.
size_t
view_size_
=
0
;
// Number of bytes between the starts of consecutive views. NB: This is not
// necessarily equal to PadToAlignment(|view_size_|).
size_t
view_stride_
=
0
;
};
}
// namespace internal
// Public aliases; use these.
using
AlignedView
=
internal
::
AlignedViewImpl
<
const
char
>
;
using
AlignedArea
=
internal
::
AlignedAreaImpl
<
const
char
>
;
using
MutableAlignedView
=
internal
::
AlignedViewImpl
<
char
>
;
using
MutableAlignedArea
=
internal
::
AlignedAreaImpl
<
char
>
;
// A uniquely-owned aligned byte array.
class
UniqueAlignedArray
{
public:
// Creates an empty byte array.
UniqueAlignedArray
()
=
default
;
// Reallocates this to |new_size| bytes, and discards the current byte array.
// Contents are uninitialized.
void
Reset
(
size_t
new_size
);
// Like Reset(), but only reallocates if |new_size| is more than the current
// capacity. NB: Does not preserve current content when reallocation occurs;
// use Resize() if that is desired.
void
Reserve
(
size_t
new_size
);
// Resizes this to contain |new_size| bytes, preserving current content. If
// |new_size| exceeds the current size, the added bytes are uninitialized. If
// |new_size| exceeds the current capacity, reallocates, and copies current
// content. Returns true if reallocation occurred.
bool
Resize
(
size_t
new_size
);
// Returns the aligned byte array.
MutableAlignedView
view
()
const
{
return
view_
;
}
private:
// Underlying byte array, which is padded for alignment.
std
::
unique_ptr
<
char
[]
>
padded_array_
;
// Size of the aligned portion of |padded_array_|.
size_t
capacity_
=
0
;
// Active range of the |storage_|.
MutableAlignedView
view_
;
};
// Implementation details below.
namespace
internal
{
// Required alignment for memory blocks. Only the runtime framework should use
// this; otherwise, DO NOT access or otherwise depend on this value.
enum
:
size_t
{
kAlignmentBytes
=
32
};
}
// namespace internal
template
<
class
T
>
constexpr
bool
IsAlignable
()
{
// Either T is divisible into alignment windows, or an alignment window is
// divisible into Ts. Likewise for T's alignment requirement. Finally, T
// must be POD because we won't call its constructor or destructor.
return
(
sizeof
(
T
)
%
internal
::
kAlignmentBytes
==
0
||
internal
::
kAlignmentBytes
%
sizeof
(
T
)
==
0
)
&&
(
alignof
(
T
)
%
internal
::
kAlignmentBytes
==
0
||
internal
::
kAlignmentBytes
%
alignof
(
T
)
==
0
)
&&
(
std
::
is_pod
<
T
>::
value
||
std
::
is_same
<
T
,
tensorflow
::
bfloat16
>::
value
);
}
inline
tensorflow
::
Status
OkIfAligned
(
const
void
*
pointer
)
{
const
uintptr_t
address
=
reinterpret_cast
<
uintptr_t
>
(
pointer
);
if
(
address
%
internal
::
kAlignmentBytes
!=
0
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Pointer fails alignment requirement: "
,
address
,
" vs required "
,
internal
::
kAlignmentBytes
);
}
return
tensorflow
::
Status
::
OK
();
}
inline
size_t
PadToAlignment
(
size_t
byte_offset
)
{
// Round up to the next alignment boundary by incrementing by a certain amount
// and then rounding down. Note that the bitmask clears the low-order bits of
// the offset, effectively rounding down to the previous alignment boundary.
return
(
byte_offset
+
internal
::
kAlignmentBytes
-
1
)
&
~
(
internal
::
kAlignmentBytes
-
1
);
}
template
<
class
T
>
T
*
PadToAlignment
(
T
*
pointer
)
{
static_assert
(
IsAlignable
<
T
>
(),
"T is not alignable"
);
uintptr_t
address
=
reinterpret_cast
<
uintptr_t
>
(
pointer
);
address
=
(
address
+
internal
::
kAlignmentBytes
-
1
)
&
~
(
internal
::
kAlignmentBytes
-
1
);
return
reinterpret_cast
<
T
*>
(
address
);
}
inline
size_t
ComputeAlignedAreaSize
(
size_t
num_arrays
,
size_t
array_size
)
{
return
num_arrays
*
PadToAlignment
(
array_size
);
}
inline
size_t
ComputeTotalBytesWithAlignmentPadding
(
const
std
::
vector
<
size_t
>
&
sizes
)
{
size_t
total
=
0
;
for
(
const
size_t
size
:
sizes
)
total
+=
PadToAlignment
(
size
);
return
total
;
}
namespace
internal
{
template
<
class
Byte
>
template
<
class
OtherByte
>
AlignedViewImpl
<
Byte
>::
AlignedViewImpl
(
AlignedViewImpl
<
OtherByte
>
that
)
:
data_
(
reinterpret_cast
<
Byte
*>
(
that
.
data
())),
size_
(
that
.
size
())
{}
template
<
class
Byte
>
template
<
class
OtherByte
>
AlignedViewImpl
<
Byte
>
&
AlignedViewImpl
<
Byte
>::
operator
=
(
AlignedViewImpl
<
OtherByte
>
that
)
{
data_
=
reinterpret_cast
<
Byte
*>
(
that
.
data
());
size_
=
that
.
size
();
return
*
this
;
}
template
<
class
Byte
>
tensorflow
::
Status
AlignedViewImpl
<
Byte
>::
Reset
(
Byte
*
data
,
size_t
size
)
{
TF_RETURN_IF_ERROR
(
OkIfAligned
(
data
));
// Success; make modifications.
data_
=
data
;
size_
=
size
;
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Byte
>
template
<
class
OtherByte
>
tensorflow
::
Status
AlignedViewImpl
<
Byte
>::
Split
(
const
std
::
vector
<
size_t
>
&
sizes
,
std
::
vector
<
AlignedViewImpl
<
OtherByte
>>
*
views
)
const
{
const
size_t
total_bytes
=
ComputeTotalBytesWithAlignmentPadding
(
sizes
);
if
(
size
()
<
total_bytes
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"View is too small to be split into sizes ["
,
tensorflow
::
str_util
::
Join
(
sizes
,
", "
),
"]: need "
,
total_bytes
,
" bytes but have "
,
size
(),
" bytes"
);
}
// Success; make modifications.
views
->
clear
();
views
->
reserve
(
sizes
.
size
());
Byte
*
base
=
data
();
for
(
const
size_t
size
:
sizes
)
{
views
->
push_back
(
AlignedViewImpl
<
OtherByte
>
(
base
,
size
));
base
=
PadToAlignment
(
base
+
size
);
}
DCHECK_EQ
(
base
-
data
(),
total_bytes
);
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Byte
>
AlignedViewImpl
<
Byte
>::
AlignedViewImpl
(
Byte
*
data
,
size_t
size
)
:
data_
(
data
),
size_
(
size
)
{
TF_DCHECK_OK
(
OkIfAligned
(
data_
));
}
template
<
class
Byte
>
template
<
class
OtherByte
>
AlignedAreaImpl
<
Byte
>::
AlignedAreaImpl
(
AlignedAreaImpl
<
OtherByte
>
that
)
:
data_
(
reinterpret_cast
<
Byte
*>
(
that
.
data_
)),
num_views_
(
that
.
num_views
()),
view_size_
(
that
.
view_size
()),
view_stride_
(
that
.
view_stride_
)
{}
template
<
class
Byte
>
template
<
class
OtherByte
>
AlignedAreaImpl
<
Byte
>
&
AlignedAreaImpl
<
Byte
>::
operator
=
(
AlignedAreaImpl
<
OtherByte
>
that
)
{
data_
=
reinterpret_cast
<
Byte
*>
(
that
.
data_
);
num_views_
=
that
.
num_views
();
view_size_
=
that
.
view_size
();
view_stride_
=
that
.
view_stride_
;
return
*
this
;
}
template
<
class
Byte
>
template
<
class
OtherByte
>
tensorflow
::
Status
AlignedAreaImpl
<
Byte
>::
Reset
(
AlignedViewImpl
<
OtherByte
>
view
,
size_t
num_views
,
size_t
view_size
)
{
const
size_t
total_bytes
=
ComputeAlignedAreaSize
(
num_views
,
view_size
);
if
(
view
.
size
()
<
total_bytes
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"View is too small for area of "
,
num_views
,
" views of "
,
view_size
,
" bytes: need "
,
total_bytes
,
" bytes but got "
,
view
.
size
(),
" bytes"
);
}
// Success; make modifications.
data_
=
reinterpret_cast
<
Byte
*>
(
view
.
data
());
num_views_
=
num_views
;
view_size_
=
view_size
;
view_stride_
=
PadToAlignment
(
view_size_
);
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Byte
>
AlignedViewImpl
<
Byte
>
AlignedAreaImpl
<
Byte
>::
view
(
size_t
index
)
const
{
DCHECK_LT
(
index
,
num_views
());
return
AlignedViewImpl
<
Byte
>
(
data_
+
view_stride_
*
index
,
view_size_
);
}
template
<
class
Byte
>
AlignedAreaImpl
<
Byte
>::
AlignedAreaImpl
(
Byte
*
data
,
size_t
num_views
,
size_t
view_size
,
size_t
view_stride
)
:
data_
(
data
),
num_views_
(
num_views
),
view_size_
(
view_size
),
view_stride_
(
view_stride
)
{
TF_DCHECK_OK
(
OkIfAligned
(
data_
));
TF_DCHECK_OK
(
OkIfAligned
(
static_cast
<
const
char
*>
(
nullptr
)
+
view_stride_
));
}
}
// namespace internal
inline
void
UniqueAlignedArray
::
Reset
(
size_t
new_size
)
{
// Pad the |new_size| to the next alignment boundary, so the final bytes of
// the array are still in a full alignment window. E.g., if we resize to 48
// bytes with 32-byte alignment, then we allocate 64 bytes so the final 16
// bytes are still part of a full 32-byte alignment window.
const
size_t
aligned_size
=
PadToAlignment
(
new_size
);
// To obtain an aligned address, allocate a sufficiently-padded byte array and
// find an aligned address near the start of the block.
//
// TODO(googleuser): Alternatively, we could use library functions such as
// memalign(), posix_memalign(), or aligned_alloc(), but those may not be
// present on all platforms. Consider adding some #ifs to allow use of those
// library functions when available.
padded_array_
.
reset
(
new
char
[
aligned_size
+
internal
::
kAlignmentBytes
-
1
]);
capacity_
=
aligned_size
;
view_
.
size_
=
new_size
;
view_
.
data_
=
PadToAlignment
(
padded_array_
.
get
());
TF_DCHECK_OK
(
OkIfAligned
(
view_
.
data_
));
}
inline
void
UniqueAlignedArray
::
Reserve
(
size_t
new_size
)
{
if
(
new_size
>
capacity_
)
{
Reset
(
new_size
);
}
else
{
view_
.
size_
=
new_size
;
}
}
inline
bool
UniqueAlignedArray
::
Resize
(
size_t
new_size
)
{
// Avoid reallocation, if possible.
if
(
new_size
<=
capacity_
)
{
view_
.
size_
=
new_size
;
return
false
;
}
// Reallocate and copy. Extend the life of the old array until it is copied.
//
// Note: C realloc() can extend a byte array in place (i.e., without copying).
// Unfortunately, there is no aligned version of realloc(). Moreover, adding
// alignment padding could cause double-copying: first, when realloc() copies
// the data to the new buffer, and second, if the amount of padding required
// at the new address is not the same as before.
const
std
::
unique_ptr
<
char
[]
>
old_array
=
std
::
move
(
padded_array_
);
const
MutableAlignedView
old_view
=
view_
;
Reset
(
2
*
new_size
);
memcpy
(
view_
.
data
(),
old_view
.
data
(),
old_view
.
size
());
view_
.
size_
=
new_size
;
return
true
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_ALIGNMENT_H_
research/syntaxnet/dragnn/runtime/alignment_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/alignment.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
static_assert
(
internal
::
kAlignmentBytes
>=
4
,
"alignment too small"
);
// Expects that two pointers have the same address.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
// Tests that standard scalar types are alignable.
TEST
(
IsAlignableTest
,
Alignable
)
{
EXPECT_TRUE
(
IsAlignable
<
char
>
());
EXPECT_TRUE
(
IsAlignable
<
float
>
());
EXPECT_TRUE
(
IsAlignable
<
double
>
());
}
// Tests that objects of odd sizes are not alignable.
TEST
(
IsAlignableTest
,
NotAlignable
)
{
EXPECT_FALSE
(
IsAlignable
<
char
[
3
]
>
());
EXPECT_FALSE
(
IsAlignable
<
char
[
7
]
>
());
EXPECT_FALSE
(
IsAlignable
<
char
[
7919
]
>
());
}
// Tests that OkIfAligned() returns OK on aligned pointers.
TEST
(
OkIfAlignedTest
,
Aligned
)
{
const
char
*
ptr
=
nullptr
;
TF_EXPECT_OK
(
OkIfAligned
(
ptr
));
ptr
+=
internal
::
kAlignmentBytes
;
TF_EXPECT_OK
(
OkIfAligned
(
ptr
));
ptr
+=
123
*
internal
::
kAlignmentBytes
;
TF_EXPECT_OK
(
OkIfAligned
(
ptr
));
}
// Tests that OkIfAligned() returns non-OK on misaligned pointers.
TEST
(
OkIfAlignedTest
,
NotAligned
)
{
const
char
*
ptr
=
nullptr
;
EXPECT_THAT
(
OkIfAligned
(
ptr
+
1
),
test
::
IsErrorWithSubstr
(
"Pointer fails alignment requirement"
));
EXPECT_THAT
(
OkIfAligned
(
ptr
+
23
),
test
::
IsErrorWithSubstr
(
"Pointer fails alignment requirement"
));
}
// Tests that any window of |internal::kAlignmentBytes| bytes contains exactly
// one aligned address.
TEST
(
OkIfAlignedTest
,
OnePerAlignmentWindow
)
{
// Note that |bytes| does not necessarily start at an aligned address. Even
// so, it is still guaranteed to contain exactly one aligned address, in the
// same sense that any sequence of 10 consecutive integers contains exactly
// one whose decimal representation ends in '0'. This property is exploited
// in UniqueAlignedArray::Reset().
const
string
bytes
(
internal
::
kAlignmentBytes
,
' '
);
int
num_ok
=
0
;
for
(
int
i
=
0
;
i
<
bytes
.
size
();
++
i
)
{
if
(
OkIfAligned
(
bytes
.
data
()
+
i
).
ok
())
++
num_ok
;
}
EXPECT_EQ
(
num_ok
,
1
);
}
// Tests that PadToAlignment() produces an aligned byte offset.
TEST
(
PadToAlignmentTest
,
Offset
)
{
EXPECT_EQ
(
PadToAlignment
(
0
),
0
);
EXPECT_EQ
(
PadToAlignment
(
1
),
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
PadToAlignment
(
internal
::
kAlignmentBytes
+
1
),
2
*
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
PadToAlignment
(
99
*
internal
::
kAlignmentBytes
+
3
),
100
*
internal
::
kAlignmentBytes
);
}
// Tests that PadToAlignment() produces an aligned pointer.
TEST
(
PadToAlignmentTest
,
Pointer
)
{
const
string
bytes
=
"hello"
;
TF_EXPECT_OK
(
OkIfAligned
(
PadToAlignment
(
bytes
.
data
())));
const
std
::
vector
<
float
>
reals
(
10
);
TF_EXPECT_OK
(
OkIfAligned
(
PadToAlignment
(
reals
.
data
())));
}
// Tests that ComputeAlignedAreaSize() calculates the correct size.
TEST
(
ComputeAlignedAreaSizeTest
,
Basic
)
{
EXPECT_EQ
(
ComputeAlignedAreaSize
(
0
,
0
),
0
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
0
,
1
),
0
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
1
,
0
),
0
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
1
,
1
),
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
1
,
internal
::
kAlignmentBytes
),
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
3
,
internal
::
kAlignmentBytes
+
1
),
6
*
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
11
,
internal
::
kAlignmentBytes
-
1
),
11
*
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeAlignedAreaSize
(
7
,
internal
::
kAlignmentBytes
),
7
*
internal
::
kAlignmentBytes
);
}
// Tests that ComputeTotalBytesWithAlignmentPadding() calculates the correct
// total size.
TEST
(
ComputeTotalBytesWithAlignmentPaddingTest
,
DifferentSizes
)
{
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({}),
0
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
0
}),
0
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
0
,
0
,
0
}),
0
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
1
}),
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
1
,
1
,
1
}),
3
*
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
(
{
1
,
internal
::
kAlignmentBytes
,
internal
::
kAlignmentBytes
+
1
}),
4
*
internal
::
kAlignmentBytes
);
std
::
vector
<
size_t
>
sizes
;
for
(
size_t
i
=
1
;
i
<=
internal
::
kAlignmentBytes
;
++
i
)
sizes
.
push_back
(
i
);
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
(
sizes
),
internal
::
kAlignmentBytes
*
internal
::
kAlignmentBytes
);
}
// Tests that ComputeTotalBytesWithAlignmentPadding() is equivalent to
// ComputeAlignedAreaSize() when all sizes are equal.
TEST
(
ComputeTotalBytesWithAlignmentPaddingTest
,
AllSameSize
)
{
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
1
,
1
,
1
,
1
}),
ComputeAlignedAreaSize
(
4
,
1
));
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
7
,
7
,
7
,
7
,
7
,
7
}),
ComputeAlignedAreaSize
(
6
,
7
));
EXPECT_EQ
(
ComputeTotalBytesWithAlignmentPadding
({
77
,
77
,
77
}),
ComputeAlignedAreaSize
(
3
,
77
));
}
// Tests that UniqueAlignedArray is empty by default.
TEST
(
UniqueAlignedArrayTest
,
EmptyByDefault
)
{
UniqueAlignedArray
array
;
EXPECT_EQ
(
array
.
view
().
size
(),
0
);
EXPECT_TRUE
(
array
.
view
().
empty
());
}
// Tests that UniqueAlignedArray::Reset() always reallocates.
TEST
(
UniqueAlignedArrayTest
,
Reset
)
{
UniqueAlignedArray
array
;
// Reset to non-empty.
array
.
Reset
(
10
);
const
MutableAlignedView
view1
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view1
.
data
()));
EXPECT_EQ
(
view1
.
size
(),
10
);
// Calling view() again should return the same byte array.
const
MutableAlignedView
view2
=
array
.
view
();
ExpectSameAddress
(
view2
.
data
(),
view1
.
data
());
EXPECT_EQ
(
view2
.
size
(),
view1
.
size
());
// Reset to a different size.
array
.
Reset
(
33
);
const
MutableAlignedView
view3
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view3
.
data
()));
EXPECT_EQ
(
view3
.
size
(),
33
);
}
// Tests that UniqueAlignedArray::Reset() reallocates when growing.
TEST
(
UniqueAlignedArrayTest
,
Reserve
)
{
UniqueAlignedArray
array
;
// Reset to non-empty.
array
.
Reserve
(
20
);
const
MutableAlignedView
view1
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view1
.
data
()));
EXPECT_EQ
(
view1
.
size
(),
20
);
// Shrink to a smaller size; should not reallocate.
array
.
Reserve
(
7
);
const
MutableAlignedView
view2
=
array
.
view
();
ExpectSameAddress
(
view2
.
data
(),
view1
.
data
());
EXPECT_EQ
(
view2
.
size
(),
7
);
// Grow but still remain within capacity; should not reallocate.
array
.
Reserve
(
14
);
const
MutableAlignedView
view3
=
array
.
view
();
ExpectSameAddress
(
view3
.
data
(),
view1
.
data
());
EXPECT_EQ
(
view3
.
size
(),
14
);
}
// Tests that UniqueAlignedArray::Resize() reallocates when growing and
// preserves existing contents.
TEST
(
UniqueAlignedArrayTest
,
Resize
)
{
UniqueAlignedArray
array
;
// Resize to non-empty.
EXPECT_TRUE
(
array
.
Resize
(
10
));
const
MutableAlignedView
view1
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view1
.
data
()));
EXPECT_EQ
(
view1
.
size
(),
10
);
// Write some stuff.
for
(
int
i
=
0
;
i
<
10
;
++
i
)
view1
.
data
()[
i
]
=
'1'
;
// Resize to a larger size.
EXPECT_TRUE
(
array
.
Resize
(
33
));
const
MutableAlignedView
view2
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view2
.
data
()));
EXPECT_EQ
(
view2
.
size
(),
33
);
// Check that content was preserved.
for
(
int
i
=
0
;
i
<
10
;
++
i
)
EXPECT_EQ
(
view2
.
data
()[
i
],
'1'
);
// Append more stuff.
for
(
int
i
=
10
;
i
<
33
;
++
i
)
view2
.
data
()[
i
]
=
'2'
;
// Resize to a smaller size.
EXPECT_FALSE
(
array
.
Resize
(
15
));
const
MutableAlignedView
view3
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view3
.
data
()));
ExpectSameAddress
(
view3
.
data
(),
view2
.
data
());
EXPECT_EQ
(
view3
.
size
(),
15
);
// Check that content was preserved.
for
(
int
i
=
0
;
i
<
10
;
++
i
)
EXPECT_EQ
(
view3
.
data
()[
i
],
'1'
);
for
(
int
i
=
10
;
i
<
15
;
++
i
)
EXPECT_EQ
(
view3
.
data
()[
i
],
'2'
);
// Overwrite with new stuff.
for
(
int
i
=
0
;
i
<
15
;
++
i
)
view3
.
data
()[
i
]
=
'3'
;
// Resize to a larger size, but still below capacity.
EXPECT_FALSE
(
array
.
Resize
(
20
));
const
MutableAlignedView
view4
=
array
.
view
();
TF_EXPECT_OK
(
OkIfAligned
(
view4
.
data
()));
ExpectSameAddress
(
view4
.
data
(),
view2
.
data
());
EXPECT_EQ
(
view4
.
size
(),
20
);
// Check that content was preserved.
for
(
int
i
=
0
;
i
<
15
;
++
i
)
EXPECT_EQ
(
view4
.
data
()[
i
],
'3'
);
}
// Tests that (Mutable)AlignedView is empty by default.
TEST
(
AlignedViewTest
,
EmptyByDefault
)
{
AlignedView
view1
;
EXPECT_EQ
(
view1
.
size
(),
0
);
EXPECT_TRUE
(
view1
.
empty
());
MutableAlignedView
view2
;
EXPECT_EQ
(
view2
.
size
(),
0
);
EXPECT_TRUE
(
view2
.
empty
());
}
// Tests that (Mutable)AlignedView::Reset() works on aligned pointers.
TEST
(
AlignedViewTest
,
ResetValid
)
{
char
*
pointer
=
nullptr
;
pointer
+=
3
*
internal
::
kAlignmentBytes
;
AlignedView
view1
;
TF_EXPECT_OK
(
view1
.
Reset
(
pointer
,
100
));
ExpectSameAddress
(
view1
.
data
(),
pointer
);
EXPECT_EQ
(
view1
.
size
(),
100
);
EXPECT_FALSE
(
view1
.
empty
());
MutableAlignedView
view2
;
TF_EXPECT_OK
(
view2
.
Reset
(
pointer
,
100
));
ExpectSameAddress
(
view2
.
data
(),
pointer
);
EXPECT_EQ
(
view2
.
size
(),
100
);
EXPECT_FALSE
(
view2
.
empty
());
}
// Tests that (Mutable)AlignedView::Reset() fails on misaligned pointers.
TEST
(
AlignedViewTest
,
ResetInvalid
)
{
char
*
pointer
=
nullptr
;
++
pointer
;
// not aligned
AlignedView
view1
;
EXPECT_THAT
(
view1
.
Reset
(
pointer
,
10
),
test
::
IsErrorWithSubstr
(
"Pointer fails alignment requirement"
));
MutableAlignedView
view2
;
EXPECT_THAT
(
view2
.
Reset
(
pointer
,
10
),
test
::
IsErrorWithSubstr
(
"Pointer fails alignment requirement"
));
}
// Tests that (Mutable)AlignedView::Reset() can empty the view.
TEST
(
AlignedViewTest
,
ResetEmpty
)
{
char
*
pointer
=
nullptr
;
pointer
+=
11
*
internal
::
kAlignmentBytes
;
// First point to a non-empty byte array.
AlignedView
view1
;
TF_EXPECT_OK
(
view1
.
Reset
(
pointer
,
100
));
ExpectSameAddress
(
view1
.
data
(),
pointer
);
EXPECT_EQ
(
view1
.
size
(),
100
);
EXPECT_FALSE
(
view1
.
empty
());
// Then reset to empty.
TF_EXPECT_OK
(
view1
.
Reset
(
pointer
,
0
));
EXPECT_EQ
(
view1
.
size
(),
0
);
EXPECT_TRUE
(
view1
.
empty
());
// First point to a non-empty byte array.
MutableAlignedView
view2
;
TF_EXPECT_OK
(
view2
.
Reset
(
pointer
,
100
));
ExpectSameAddress
(
view2
.
data
(),
pointer
);
EXPECT_EQ
(
view2
.
size
(),
100
);
EXPECT_FALSE
(
view2
.
empty
());
// Then reset to empty.
TF_EXPECT_OK
(
view2
.
Reset
(
pointer
,
0
));
EXPECT_EQ
(
view2
.
size
(),
0
);
EXPECT_TRUE
(
view2
.
empty
());
}
// Tests that (Mutable)AlignedView supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
TEST
(
AlignedViewTest
,
CopyAndAssign
)
{
char
*
pointer1
=
nullptr
;
pointer1
+=
3
*
internal
::
kAlignmentBytes
;
const
char
*
pointer2
=
nullptr
;
pointer2
+=
7
*
internal
::
kAlignmentBytes
;
MutableAlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
pointer1
,
100
));
AlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
pointer2
,
200
));
MutableAlignedView
view3
(
view1
);
ExpectSameAddress
(
view3
.
data
(),
pointer1
);
EXPECT_EQ
(
view3
.
size
(),
100
);
EXPECT_FALSE
(
view3
.
empty
());
view3
=
MutableAlignedView
();
EXPECT_EQ
(
view3
.
size
(),
0
);
EXPECT_TRUE
(
view3
.
empty
());
view3
=
view1
;
ExpectSameAddress
(
view3
.
data
(),
pointer1
);
EXPECT_EQ
(
view3
.
size
(),
100
);
EXPECT_FALSE
(
view3
.
empty
());
AlignedView
view4
(
view1
);
// reinterprets type
ExpectSameAddress
(
view4
.
data
(),
pointer1
);
EXPECT_EQ
(
view4
.
size
(),
100
);
EXPECT_FALSE
(
view4
.
empty
());
view4
=
AlignedView
();
EXPECT_EQ
(
view4
.
size
(),
0
);
EXPECT_TRUE
(
view4
.
empty
());
view4
=
view2
;
ExpectSameAddress
(
view4
.
data
(),
pointer2
);
EXPECT_EQ
(
view4
.
size
(),
200
);
EXPECT_FALSE
(
view4
.
empty
());
view4
=
view1
;
// reinterprets type
ExpectSameAddress
(
view4
.
data
(),
pointer1
);
EXPECT_EQ
(
view4
.
size
(),
100
);
EXPECT_FALSE
(
view4
.
empty
());
view4
=
MutableAlignedView
();
// reinterprets type
EXPECT_EQ
(
view4
.
size
(),
0
);
EXPECT_TRUE
(
view4
.
empty
());
}
// Tests that AlignedView can split itself into sub-views with specified sizes.
TEST
(
AlignedViewTest
,
SplitConst
)
{
const
std
::
vector
<
size_t
>
sizes
=
{
1
,
internal
::
kAlignmentBytes
,
internal
::
kAlignmentBytes
+
1
,
1
,
123
};
const
size_t
total_bytes
=
ComputeTotalBytesWithAlignmentPadding
(
sizes
);
AlignedView
view
;
TF_ASSERT_OK
(
view
.
Reset
(
nullptr
,
total_bytes
));
std
::
vector
<
AlignedView
>
views
(
100
);
// will be resized
TF_ASSERT_OK
(
view
.
Split
(
sizes
,
&
views
));
ASSERT_EQ
(
views
.
size
(),
5
);
const
char
*
base
=
view
.
data
();
ExpectSameAddress
(
views
[
0
].
data
(),
base
);
EXPECT_EQ
(
views
[
0
].
size
(),
1
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
views
[
1
].
data
(),
base
);
EXPECT_EQ
(
views
[
1
].
size
(),
internal
::
kAlignmentBytes
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
views
[
2
].
data
(),
base
);
EXPECT_EQ
(
views
[
2
].
size
(),
internal
::
kAlignmentBytes
+
1
);
base
+=
2
*
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
views
[
3
].
data
(),
base
);
EXPECT_EQ
(
views
[
3
].
size
(),
1
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
views
[
4
].
data
(),
base
);
EXPECT_EQ
(
views
[
4
].
size
(),
123
);
}
// Tests that MutableAlignedView can split itself into sub-views with specified
// sizes, and reinterprets from char* to const char*.
TEST
(
AlignedViewTest
,
SplitMutable
)
{
const
std
::
vector
<
size_t
>
sizes
=
{
1
,
internal
::
kAlignmentBytes
,
internal
::
kAlignmentBytes
+
1
,
1
,
123
};
const
size_t
total_bytes
=
ComputeTotalBytesWithAlignmentPadding
(
sizes
);
// Also add some padding to check that we can split part of the view.
MutableAlignedView
view
;
TF_ASSERT_OK
(
view
.
Reset
(
nullptr
,
total_bytes
+
10
));
std
::
vector
<
AlignedView
>
const_views
(
99
);
// will be resized
std
::
vector
<
MutableAlignedView
>
mutable_views
(
2
);
// will be resized
TF_ASSERT_OK
(
view
.
Split
(
sizes
,
&
const_views
));
TF_ASSERT_OK
(
view
.
Split
(
sizes
,
&
mutable_views
));
ASSERT_EQ
(
const_views
.
size
(),
5
);
ASSERT_EQ
(
mutable_views
.
size
(),
5
);
const
char
*
base
=
view
.
data
();
ExpectSameAddress
(
const_views
[
0
].
data
(),
base
);
ExpectSameAddress
(
mutable_views
[
0
].
data
(),
base
);
EXPECT_EQ
(
const_views
[
0
].
size
(),
1
);
EXPECT_EQ
(
mutable_views
[
0
].
size
(),
1
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
const_views
[
1
].
data
(),
base
);
ExpectSameAddress
(
mutable_views
[
1
].
data
(),
base
);
EXPECT_EQ
(
const_views
[
1
].
size
(),
internal
::
kAlignmentBytes
);
EXPECT_EQ
(
mutable_views
[
1
].
size
(),
internal
::
kAlignmentBytes
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
const_views
[
2
].
data
(),
base
);
ExpectSameAddress
(
mutable_views
[
2
].
data
(),
base
);
EXPECT_EQ
(
const_views
[
2
].
size
(),
internal
::
kAlignmentBytes
+
1
);
EXPECT_EQ
(
mutable_views
[
2
].
size
(),
internal
::
kAlignmentBytes
+
1
);
base
+=
2
*
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
const_views
[
3
].
data
(),
base
);
ExpectSameAddress
(
mutable_views
[
3
].
data
(),
base
);
EXPECT_EQ
(
const_views
[
3
].
size
(),
1
);
EXPECT_EQ
(
mutable_views
[
3
].
size
(),
1
);
base
+=
internal
::
kAlignmentBytes
;
ExpectSameAddress
(
const_views
[
4
].
data
(),
base
);
ExpectSameAddress
(
mutable_views
[
4
].
data
(),
base
);
EXPECT_EQ
(
const_views
[
4
].
size
(),
123
);
EXPECT_EQ
(
mutable_views
[
4
].
size
(),
123
);
}
TEST
(
AlignedViewTest
,
SplitTooSmall
)
{
const
std
::
vector
<
size_t
>
sizes
=
{
1
,
internal
::
kAlignmentBytes
,
internal
::
kAlignmentBytes
+
1
,
1
,
123
};
const
size_t
total_bytes
=
ComputeTotalBytesWithAlignmentPadding
(
sizes
);
// Make the view just a bit too small.
MutableAlignedView
view
;
TF_ASSERT_OK
(
view
.
Reset
(
nullptr
,
total_bytes
-
1
));
std
::
vector
<
MutableAlignedView
>
views
;
EXPECT_THAT
(
view
.
Split
(
sizes
,
&
views
),
test
::
IsErrorWithSubstr
(
"View is too small to be split"
));
}
// Tests that (Mutable)AlignedArea is empty by default.
TEST
(
AlignedAreaTest
,
EmptyByDefault
)
{
AlignedArea
area1
;
EXPECT_EQ
(
area1
.
num_views
(),
0
);
EXPECT_EQ
(
area1
.
view_size
(),
0
);
EXPECT_TRUE
(
area1
.
empty
());
MutableAlignedArea
area2
;
EXPECT_EQ
(
area2
.
num_views
(),
0
);
EXPECT_EQ
(
area2
.
view_size
(),
0
);
EXPECT_TRUE
(
area2
.
empty
());
}
// Tests that (Mutable)AlignedArea::Reset() can initialize to a single view.
TEST
(
AlignedAreaTest
,
ResetSingleton
)
{
const
char
*
pointer1
=
nullptr
;
pointer1
+=
3
*
internal
::
kAlignmentBytes
;
char
*
pointer2
=
nullptr
;
pointer2
+=
7
*
internal
::
kAlignmentBytes
;
AlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
pointer1
,
internal
::
kAlignmentBytes
));
MutableAlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
pointer2
,
internal
::
kAlignmentBytes
+
1
));
AlignedArea
area1
;
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
1
,
1
));
EXPECT_EQ
(
area1
.
num_views
(),
1
);
EXPECT_EQ
(
area1
.
view_size
(),
1
);
EXPECT_FALSE
(
area1
.
empty
());
ExpectSameAddress
(
area1
.
view
(
0
).
data
(),
pointer1
);
EXPECT_EQ
(
area1
.
view
(
0
).
size
(),
1
);
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
1
,
2
));
EXPECT_EQ
(
area1
.
num_views
(),
1
);
EXPECT_EQ
(
area1
.
view_size
(),
2
);
EXPECT_FALSE
(
area1
.
empty
());
ExpectSameAddress
(
area1
.
view
(
0
).
data
(),
pointer2
);
EXPECT_EQ
(
area1
.
view
(
0
).
size
(),
2
);
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
1
,
1
));
EXPECT_EQ
(
area1
.
num_views
(),
1
);
EXPECT_EQ
(
area1
.
view_size
(),
1
);
EXPECT_FALSE
(
area1
.
empty
());
ExpectSameAddress
(
area1
.
view
(
0
).
data
(),
pointer2
);
EXPECT_EQ
(
area1
.
view
(
0
).
size
(),
1
);
MutableAlignedArea
area2
;
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
1
,
2
));
EXPECT_EQ
(
area2
.
num_views
(),
1
);
EXPECT_EQ
(
area2
.
view_size
(),
2
);
EXPECT_FALSE
(
area2
.
empty
());
ExpectSameAddress
(
area2
.
view
(
0
).
data
(),
pointer2
);
EXPECT_EQ
(
area2
.
view
(
0
).
size
(),
2
);
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
1
,
1
));
EXPECT_EQ
(
area2
.
num_views
(),
1
);
EXPECT_EQ
(
area2
.
view_size
(),
1
);
EXPECT_FALSE
(
area2
.
empty
());
ExpectSameAddress
(
area2
.
view
(
0
).
data
(),
pointer2
);
EXPECT_EQ
(
area2
.
view
(
0
).
size
(),
1
);
}
// Tests that (Mutable)AlignedArea::Reset() can initialize to a sequence of
// multiple views.
TEST
(
AlignedAreaTest
,
ResetMultiple
)
{
const
char
*
pointer1
=
nullptr
;
pointer1
+=
3
*
internal
::
kAlignmentBytes
;
char
*
pointer2
=
nullptr
;
pointer2
+=
7
*
internal
::
kAlignmentBytes
;
AlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
pointer1
,
11
*
internal
::
kAlignmentBytes
));
MutableAlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
pointer2
,
2
*
internal
::
kAlignmentBytes
));
AlignedArea
area1
;
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
11
,
1
));
EXPECT_EQ
(
area1
.
num_views
(),
11
);
EXPECT_EQ
(
area1
.
view_size
(),
1
);
EXPECT_FALSE
(
area1
.
empty
());
for
(
int
i
=
0
;
i
<
area1
.
num_views
();
++
i
)
{
ExpectSameAddress
(
area1
.
view
(
i
).
data
(),
pointer1
+
internal
::
kAlignmentBytes
*
i
);
EXPECT_EQ
(
area1
.
view
(
i
).
size
(),
1
);
}
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
10
,
internal
::
kAlignmentBytes
));
EXPECT_EQ
(
area1
.
num_views
(),
10
);
EXPECT_EQ
(
area1
.
view_size
(),
internal
::
kAlignmentBytes
);
EXPECT_FALSE
(
area1
.
empty
());
for
(
int
i
=
0
;
i
<
area1
.
num_views
();
++
i
)
{
ExpectSameAddress
(
area1
.
view
(
i
).
data
(),
pointer1
+
internal
::
kAlignmentBytes
*
i
);
EXPECT_EQ
(
area1
.
view
(
i
).
size
(),
internal
::
kAlignmentBytes
);
}
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
2
,
2
));
EXPECT_EQ
(
area1
.
num_views
(),
2
);
EXPECT_EQ
(
area1
.
view_size
(),
2
);
EXPECT_FALSE
(
area1
.
empty
());
for
(
int
i
=
0
;
i
<
area1
.
num_views
();
++
i
)
{
ExpectSameAddress
(
area1
.
view
(
i
).
data
(),
pointer2
+
internal
::
kAlignmentBytes
*
i
);
EXPECT_EQ
(
area1
.
view
(
i
).
size
(),
2
);
}
MutableAlignedArea
area2
;
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
2
,
internal
::
kAlignmentBytes
));
EXPECT_EQ
(
area2
.
num_views
(),
2
);
EXPECT_EQ
(
area2
.
view_size
(),
internal
::
kAlignmentBytes
);
EXPECT_FALSE
(
area2
.
empty
());
for
(
int
i
=
0
;
i
<
area2
.
num_views
();
++
i
)
{
ExpectSameAddress
(
area2
.
view
(
i
).
data
(),
pointer2
+
internal
::
kAlignmentBytes
*
i
);
EXPECT_EQ
(
area2
.
view
(
i
).
size
(),
internal
::
kAlignmentBytes
);
}
}
// Tests that (Mutable)AlignedArea::Reset() fails when the view being split into
// sub-views is too small.
TEST
(
AlignedAreaTest
,
ResetInvalid
)
{
AlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
nullptr
,
11
*
internal
::
kAlignmentBytes
));
MutableAlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
nullptr
,
2
*
internal
::
kAlignmentBytes
));
// View size larger than available view.
AlignedArea
area
;
EXPECT_THAT
(
area
.
Reset
(
view1
,
1
,
11
*
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view1
,
11
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view2
,
1
,
2
*
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view2
,
2
,
1
));
// Total size larger than available view.
EXPECT_THAT
(
area
.
Reset
(
view1
,
12
,
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view1
,
11
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view1
,
4
,
2
*
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view1
,
11
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view1
,
3
,
3
*
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view1
,
11
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view1
,
2
,
5
*
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view1
,
11
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view2
,
3
,
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view2
,
2
,
1
));
EXPECT_THAT
(
area
.
Reset
(
view2
,
2
,
internal
::
kAlignmentBytes
+
1
),
test
::
IsErrorWithSubstr
(
"View is too small for area"
));
TF_ASSERT_OK
(
area
.
Reset
(
view2
,
2
,
1
));
}
// Tests that (Mutable)AlignedView::Reset() can empty the area.
TEST
(
AlignedAreaTest
,
ResetEmpty
)
{
AlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
nullptr
,
11
*
internal
::
kAlignmentBytes
));
MutableAlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
nullptr
,
2
*
internal
::
kAlignmentBytes
));
// First point to a non-empty byte array, then clear.
AlignedArea
area1
;
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
11
,
1
));
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
0
,
0
));
EXPECT_EQ
(
area1
.
num_views
(),
0
);
EXPECT_EQ
(
area1
.
view_size
(),
0
);
EXPECT_TRUE
(
area1
.
empty
());
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
2
,
1
));
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
0
,
100
));
EXPECT_EQ
(
area1
.
num_views
(),
0
);
EXPECT_EQ
(
area1
.
view_size
(),
100
);
EXPECT_TRUE
(
area1
.
empty
());
TF_ASSERT_OK
(
area1
.
Reset
(
view2
,
2
,
1
));
TF_ASSERT_OK
(
area1
.
Reset
(
MutableAlignedView
(),
0
,
1
));
EXPECT_EQ
(
area1
.
num_views
(),
0
);
EXPECT_EQ
(
area1
.
view_size
(),
1
);
EXPECT_TRUE
(
area1
.
empty
());
MutableAlignedArea
area2
;
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
2
,
1
));
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
0
,
0
));
EXPECT_EQ
(
area2
.
num_views
(),
0
);
EXPECT_EQ
(
area2
.
view_size
(),
0
);
EXPECT_TRUE
(
area2
.
empty
());
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
2
,
1
));
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
0
,
100
));
EXPECT_EQ
(
area2
.
num_views
(),
0
);
EXPECT_EQ
(
area2
.
view_size
(),
100
);
EXPECT_TRUE
(
area2
.
empty
());
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
2
,
1
));
TF_ASSERT_OK
(
area2
.
Reset
(
MutableAlignedView
(),
0
,
1
));
EXPECT_EQ
(
area2
.
num_views
(),
0
);
EXPECT_EQ
(
area2
.
view_size
(),
1
);
EXPECT_TRUE
(
area2
.
empty
());
}
// Tests that (Mutable)AlignedArea supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
TEST
(
AlignedAreaTest
,
CopyAndAssign
)
{
char
*
pointer1
=
nullptr
;
pointer1
+=
3
*
internal
::
kAlignmentBytes
;
const
char
*
pointer2
=
nullptr
;
pointer2
+=
7
*
internal
::
kAlignmentBytes
;
MutableAlignedView
view1
;
TF_ASSERT_OK
(
view1
.
Reset
(
pointer1
,
ComputeAlignedAreaSize
(
1
,
5
)));
AlignedView
view2
;
TF_ASSERT_OK
(
view2
.
Reset
(
pointer2
,
ComputeAlignedAreaSize
(
2
,
77
)));
MutableAlignedArea
area1
;
TF_ASSERT_OK
(
area1
.
Reset
(
view1
,
1
,
5
));
AlignedArea
area2
;
TF_ASSERT_OK
(
area2
.
Reset
(
view2
,
2
,
77
));
MutableAlignedArea
area3
(
area1
);
EXPECT_EQ
(
area3
.
num_views
(),
1
);
EXPECT_EQ
(
area3
.
view_size
(),
5
);
EXPECT_FALSE
(
area3
.
empty
());
ExpectSameAddress
(
area3
.
view
(
0
).
data
(),
pointer1
);
EXPECT_EQ
(
area3
.
view
(
0
).
size
(),
5
);
area3
=
MutableAlignedArea
();
EXPECT_EQ
(
area3
.
num_views
(),
0
);
EXPECT_EQ
(
area3
.
view_size
(),
0
);
EXPECT_TRUE
(
area3
.
empty
());
area3
=
area1
;
EXPECT_EQ
(
area3
.
num_views
(),
1
);
EXPECT_EQ
(
area3
.
view_size
(),
5
);
EXPECT_FALSE
(
area3
.
empty
());
ExpectSameAddress
(
area3
.
view
(
0
).
data
(),
pointer1
);
EXPECT_EQ
(
area3
.
view
(
0
).
size
(),
5
);
AlignedArea
area4
(
area1
);
// reinterprets type
EXPECT_EQ
(
area4
.
num_views
(),
1
);
EXPECT_EQ
(
area4
.
view_size
(),
5
);
EXPECT_FALSE
(
area4
.
empty
());
ExpectSameAddress
(
area4
.
view
(
0
).
data
(),
pointer1
);
EXPECT_EQ
(
area4
.
view
(
0
).
size
(),
5
);
area4
=
AlignedArea
();
EXPECT_EQ
(
area4
.
num_views
(),
0
);
EXPECT_EQ
(
area4
.
view_size
(),
0
);
EXPECT_TRUE
(
area4
.
empty
());
area4
=
area2
;
EXPECT_EQ
(
area4
.
num_views
(),
2
);
EXPECT_EQ
(
area4
.
view_size
(),
77
);
EXPECT_FALSE
(
area4
.
empty
());
ExpectSameAddress
(
area4
.
view
(
0
).
data
(),
pointer2
);
EXPECT_EQ
(
area4
.
view
(
0
).
size
(),
77
);
ExpectSameAddress
(
area4
.
view
(
1
).
data
(),
PadToAlignment
(
pointer2
+
77
));
EXPECT_EQ
(
area4
.
view
(
1
).
size
(),
77
);
area4
=
area1
;
// reinterprets type
EXPECT_EQ
(
area4
.
num_views
(),
1
);
EXPECT_EQ
(
area4
.
view_size
(),
5
);
EXPECT_FALSE
(
area4
.
empty
());
ExpectSameAddress
(
area4
.
view
(
0
).
data
(),
pointer1
);
EXPECT_EQ
(
area4
.
view
(
0
).
size
(),
5
);
area4
=
MutableAlignedArea
();
// reinterprets type
EXPECT_EQ
(
area4
.
num_views
(),
0
);
EXPECT_EQ
(
area4
.
view_size
(),
0
);
EXPECT_TRUE
(
area4
.
empty
());
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/array_variable_store.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Increment this if the serialized format changes in an incompatible way that
// can't be detected through other means. For example,
// * If kAlignmentBytes is changed, then kVersion need not change because there
// is a separate field for detecting alignment mismatch.
// * If ArrayVariableStoreSpec.variable is no longer populated, perhaps replaced
// by some other approach, then kVersion should be incremented.
const
uint32
ArrayVariableStore
::
kVersion
=
0
;
tensorflow
::
Status
ArrayVariableStore
::
Reset
(
const
ArrayVariableStoreSpec
&
spec
,
AlignedView
data
)
{
if
(
!
spec
.
has_version
()
||
!
spec
.
has_alignment_bytes
()
||
!
spec
.
has_is_little_endian
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"ArrayVariableStoreSpec is missing a required field: "
,
spec
.
ShortDebugString
());
}
if
(
spec
.
version
()
!=
kVersion
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"ArrayVariableStoreSpec.version ("
,
spec
.
version
(),
") does not match the binary ("
,
kVersion
,
")"
);
}
if
(
spec
.
alignment_bytes
()
!=
internal
::
kAlignmentBytes
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"ArrayVariableStoreSpec.alignment_bytes ("
,
spec
.
alignment_bytes
(),
") does not match the binary ("
,
internal
::
kAlignmentBytes
,
")"
);
}
// TODO(googleuser): It should be possible to correct an endian-ness mismatch.
// A rough outline is:
// * VariableStore::Lookup() takes an additional argument set to sizeof(T).
// * Capture sizeof(T) and write it into the VariableSpec.
// * Detect endian mismatch and byte-swap variables with multi-byte types.
if
(
spec
.
is_little_endian
()
!=
tensorflow
::
port
::
kLittleEndian
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"ArrayVariableStoreSpec.is_little_endian ("
,
spec
.
is_little_endian
(),
") does not match the binary ("
,
tensorflow
::
port
::
kLittleEndian
,
")"
);
}
for
(
const
VariableSpec
&
variable_spec
:
spec
.
variable
())
{
// When the proto parser encounters an unknown enumerator on the wire, it
// replaces it with the default value (i.e., FORMAT_UNKNOWN). Therefore,
// VariableSpec.format() will always return a valid enumerator.
DCHECK
(
VariableSpec
::
Format_IsValid
(
variable_spec
.
format
()));
if
(
variable_spec
.
format
()
==
VariableSpec
::
FORMAT_UNKNOWN
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Unknown variable format: "
,
variable_spec
.
ShortDebugString
());
}
if
(
variable_spec
.
format
()
==
VariableSpec
::
FORMAT_FLAT
&&
variable_spec
.
num_views
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Flat variables must have 1 view: "
,
variable_spec
.
ShortDebugString
());
}
}
// Build into a temp mapping to avoid modification on error.
std
::
unique_ptr
<
std
::
map
<
Key
,
Value
>>
new_variables
(
new
std
::
map
<
Key
,
Value
>
());
// Slice sub-arrays off of the main byte array.
const
char
*
base
=
data
.
data
();
const
char
*
const
end
=
base
+
data
.
size
();
for
(
const
VariableSpec
&
variable_spec
:
spec
.
variable
())
{
const
size_t
num_views
=
variable_spec
.
num_views
();
const
size_t
view_size
=
variable_spec
.
view_size
();
const
size_t
area_size
=
ComputeAlignedAreaSize
(
num_views
,
view_size
);
if
(
base
+
area_size
>
end
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Variable would overrun main byte array: "
,
variable_spec
.
ShortDebugString
());
}
AlignedView
view
;
TF_RETURN_IF_ERROR
(
view
.
Reset
(
base
,
area_size
));
base
+=
area_size
;
// remove claimed slice
// Set dimensions from the spec.
std
::
vector
<
size_t
>
dimensions
(
variable_spec
.
dimension
().
begin
(),
variable_spec
.
dimension
().
end
());
Value
value
(
std
::
move
(
dimensions
),
AlignedArea
());
AlignedArea
&
area
=
value
.
second
;
TF_RETURN_IF_ERROR
(
area
.
Reset
(
view
,
num_views
,
view_size
));
// Currently, blocked variables are meant for fast inference algorithms,
// which do not tolerate padding. Raise errors if there is padding.
if
(
variable_spec
.
format
()
==
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
)
{
size_t
padding
=
variable_spec
.
view_size
()
%
internal
::
kAlignmentBytes
;
if
(
padding
!=
0
)
{
return
tensorflow
::
errors
::
Internal
(
"Currently, fast matrix-vector operations do not support padded "
"blocked matrices, but variable '"
,
variable_spec
.
name
(),
"' has padding "
,
padding
);
}
}
const
Key
key
(
variable_spec
.
name
(),
variable_spec
.
format
());
if
(
!
new_variables
->
emplace
(
key
,
value
).
second
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Duplicate variable: "
,
variable_spec
.
ShortDebugString
());
}
}
if
(
base
!=
end
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Variables do not completely cover main byte array: "
,
end
-
base
,
" bytes remaining"
);
}
// Success; make modifications.
variables_
=
std
::
move
(
new_variables
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
ArrayVariableStore
::
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
{
if
(
!
variables_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"ArrayVariableStore not initialized"
);
}
const
Key
key
(
name
,
format
);
const
auto
it
=
variables_
->
find
(
key
);
if
(
it
==
variables_
->
end
())
{
return
tensorflow
::
errors
::
NotFound
(
"ArrayVariableStore has no variable with name '"
,
name
,
"' and format "
,
VariableSpec
::
Format_Name
(
format
));
}
// Success; make modifications.
const
Value
&
value
=
it
->
second
;
*
dimensions
=
value
.
first
;
*
area
=
value
.
second
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
ArrayVariableStore
::
Close
()
{
if
(
!
variables_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"ArrayVariableStore not initialized"
);
}
variables_
.
reset
();
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/array_variable_store.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A variable store that groups all variables into a single byte array. This
// class and its subclasses are intended for use in production.
//
// Each variable occupies a sub-array of the main byte array. The mapping from
// the name and format of a variable to the sub-array containing its content is
// defined in ArrayVariableStoreSpec. The variables may appear in any order.
//
// This format allows variables to be mapped directly into memory, which reduces
// initialization time and supports usage on-device, where mmap() is effectively
// obligatory for large data resources.
class
ArrayVariableStore
:
public
VariableStore
{
public:
// Creates an uninitialized store.
ArrayVariableStore
()
=
default
;
// Resets this to represent the variables defined by the |spec| and |data|.
// The |data| must remain valid until this is destroyed or Reset(). (Note
// that subclasses have simpler lifetime requirements). On error, returns
// non-OK and modifies nothing.
tensorflow
::
Status
Reset
(
const
ArrayVariableStoreSpec
&
spec
,
AlignedView
data
);
// Implements VariableStore.
using
VariableStore
::
Lookup
;
// import Lookup<T>() convenience methods
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
override
;
tensorflow
::
Status
Close
()
override
;
private:
friend
class
ArrayVariableStoreBuilder
;
// for access to kVersion
// The current version of the serialized format.
static
const
uint32
kVersion
;
// A (name,format) key associated with a variable.
using
Key
=
std
::
pair
<
string
,
VariableSpec
::
Format
>
;
// Dimension vector and aligned area.
using
Value
=
std
::
pair
<
const
std
::
vector
<
size_t
>
,
AlignedArea
>
;
// Mapping from variable key to variable content. Initially null, filled in
// Reset(), and deleted in Close(). Wrapped in std::unique_ptr so the entire
// mapping can be deleted.
std
::
unique_ptr
<
std
::
map
<
Key
,
Value
>>
variables_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
research/syntaxnet/dragnn/runtime/array_variable_store_builder.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <tuple>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Appends the content of the |view| to the |data|, followed by zero-padding to
// the next alignment boundary.
void
Append
(
AlignedView
view
,
string
*
data
)
{
DCHECK_EQ
(
PadToAlignment
(
data
->
size
()),
data
->
size
());
const
size_t
alignment_padding
=
PadToAlignment
(
view
.
size
())
-
view
.
size
();
data
->
append
(
view
.
data
(),
view
.
size
());
data
->
append
(
alignment_padding
,
'\0'
);
}
// As above, but for an aligned |area|.
void
Append
(
AlignedArea
area
,
string
*
data
)
{
DCHECK_EQ
(
PadToAlignment
(
data
->
size
()),
data
->
size
());
const
size_t
orig_size
=
data
->
size
();
for
(
size_t
i
=
0
;
i
<
area
.
num_views
();
++
i
)
Append
(
area
.
view
(
i
),
data
);
DCHECK_EQ
(
data
->
size
()
-
orig_size
,
ComputeAlignedAreaSize
(
area
.
num_views
(),
area
.
view_size
()));
}
}
// namespace
tensorflow
::
Status
ArrayVariableStoreBuilder
::
Build
(
const
Variables
&
variables
,
ArrayVariableStoreSpec
*
spec
,
string
*
data
)
{
data
->
clear
();
spec
->
Clear
();
spec
->
set_version
(
ArrayVariableStore
::
kVersion
);
spec
->
set_alignment_bytes
(
internal
::
kAlignmentBytes
);
spec
->
set_is_little_endian
(
tensorflow
::
port
::
kLittleEndian
);
for
(
const
auto
&
variable
:
variables
)
{
string
name
;
VariableSpec
::
Format
format
;
std
::
vector
<
size_t
>
dimensions
;
AlignedArea
area
;
std
::
tie
(
name
,
format
)
=
variable
.
first
;
std
::
tie
(
dimensions
,
area
)
=
variable
.
second
;
if
(
format
==
VariableSpec
::
FORMAT_FLAT
&&
area
.
num_views
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Flat variables must have 1 view, but '"
,
name
,
"' has "
,
area
.
num_views
());
}
VariableSpec
*
variable_spec
=
spec
->
add_variable
();
variable_spec
->
set_name
(
name
);
variable_spec
->
set_format
(
format
);
variable_spec
->
set_num_views
(
area
.
num_views
());
variable_spec
->
set_view_size
(
area
.
view_size
());
for
(
size_t
dimension
:
dimensions
)
{
variable_spec
->
add_dimension
(
dimension
);
}
Append
(
area
,
data
);
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/array_variable_store_builder.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#include <map>
#include <string>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Utils for converting a set of variables into a byte array that can be loaded
// by ArrayVariableStore. See that class for details on the required format.
class
ArrayVariableStoreBuilder
{
public:
using
Variables
=
CaptureUsedVariableStoreWrapper
::
Variables
;
// Forbids instantiation; pure static class.
ArrayVariableStoreBuilder
()
=
delete
;
~
ArrayVariableStoreBuilder
()
=
delete
;
// Overwrites the |data| with a byte array that represents the |variables|,
// and overwrites the |spec| with the associated configuration. On error,
// returns non-OK.
static
tensorflow
::
Status
Build
(
const
Variables
&
variables
,
ArrayVariableStoreSpec
*
spec
,
string
*
data
);
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
research/syntaxnet/dragnn/runtime/array_variable_store_builder_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <map>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that the builder rejects invalid flat variables.
TEST
(
ArrayVariableStoreBuilderTest
,
InvalidFlatVariable
)
{
AlignedView
view
;
ArrayVariableStoreBuilder
::
Variables
variables
;
ArrayVariableStoreSpec
spec
;
string
data
;
TF_ASSERT_OK
(
view
.
Reset
(
nullptr
,
2
*
internal
::
kAlignmentBytes
));
// Try an empty area.
std
::
pair
<
string
,
VariableSpec
::
Format
>
foo_key
(
"foo"
,
VariableSpec
::
FORMAT_FLAT
);
AlignedArea
area
;
TF_ASSERT_OK
(
area
.
Reset
(
view
,
0
,
0
));
std
::
pair
<
std
::
vector
<
size_t
>
,
AlignedArea
>
foo_value
({
1
},
area
);
variables
.
push_back
(
std
::
make_pair
(
foo_key
,
foo_value
));
EXPECT_THAT
(
ArrayVariableStoreBuilder
::
Build
(
variables
,
&
spec
,
&
data
),
test
::
IsErrorWithSubstr
(
"Flat variables must have 1 view, but 'foo' has 0"
));
// Try an area with more than 1 sub-view.
TF_ASSERT_OK
(
area
.
Reset
(
view
,
2
,
0
));
variables
[
0
].
second
.
second
=
area
;
EXPECT_THAT
(
ArrayVariableStoreBuilder
::
Build
(
variables
,
&
spec
,
&
data
),
test
::
IsErrorWithSubstr
(
"Flat variables must have 1 view, but 'foo' has 2"
));
}
// Tests that the builder succeeds on good inputs and reproduces an expected
// byte array.
//
// NB: Since this test directly compares the byte array, it implicitly requires
// that the builder lays out the variables in a particular order. If that order
// changes, the test expectations must be updated.
TEST
(
ArrayVariableStoreBuilderTest
,
RegressionTest
)
{
const
string
kLocalSpecPath
=
"dragnn/runtime/testdata/array_variable_store_spec"
;
const
string
kLocalDataPath
=
"dragnn/runtime/testdata/array_variable_store_data"
;
const
string
kExpectedSpecPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/array_variable_store_spec"
);
const
string
kExpectedDataPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/array_variable_store_data"
);
// If these values are changed, make sure to rewrite the test data and update
// array_variable_store_test.cc.
UniqueMatrix
<
float
>
foo
({{
0.0
,
0.5
,
1.0
},
//
{
1.5
,
2.0
,
2.5
},
//
{
3.0
,
3.5
,
4.0
},
//
{
4.5
,
5.0
,
5.5
}});
UniqueMatrix
<
double
>
baz_data
({{
1.0
,
2.0
,
2.0
,
2.0
},
//
{
3.0
,
4.0
,
4.0
,
4.0
},
//
{
5.0
,
6.0
,
6.0
,
6.0
},
//
{
7.0
,
8.0
,
8.0
,
8.0
}});
ArrayVariableStoreBuilder
::
Variables
variables
;
std
::
pair
<
string
,
VariableSpec
::
Format
>
foo_key
(
"foo"
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
);
std
::
pair
<
std
::
vector
<
size_t
>
,
AlignedArea
>
foo_value
(
{
foo
->
num_rows
(),
foo
->
num_columns
()},
AlignedArea
(
foo
.
area
()));
variables
.
push_back
(
std
::
make_pair
(
foo_key
,
foo_value
));
std
::
pair
<
string
,
VariableSpec
::
Format
>
baz_key
(
"baz"
,
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
);
std
::
pair
<
std
::
vector
<
size_t
>
,
AlignedArea
>
baz_value
(
{
2
,
8
,
4
},
AlignedArea
(
baz_data
.
area
()));
variables
.
push_back
(
std
::
make_pair
(
baz_key
,
baz_value
));
ArrayVariableStoreSpec
actual_spec
;
actual_spec
.
set_version
(
999
);
string
actual_data
=
"garbage to be overwritten"
;
TF_ASSERT_OK
(
ArrayVariableStoreBuilder
::
Build
(
variables
,
&
actual_spec
,
&
actual_data
));
if
(
false
)
{
// Rewrite the test data.
TF_CHECK_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
kLocalSpecPath
,
actual_spec
));
TF_CHECK_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
kLocalDataPath
,
actual_data
));
}
else
{
// Compare to the test data.
ArrayVariableStoreSpec
expected_spec
;
string
expected_data
;
TF_CHECK_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
kExpectedSpecPath
,
&
expected_spec
));
TF_CHECK_OK
(
tensorflow
::
ReadFileToString
(
tensorflow
::
Env
::
Default
(),
kExpectedDataPath
,
&
expected_data
));
EXPECT_THAT
(
actual_spec
,
test
::
EqualsProto
(
expected_spec
));
EXPECT_EQ
(
actual_data
,
expected_data
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/array_variable_store_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include <string.h>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/file_array_variable_store.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/mmap_array_variable_store.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
template
<
class
T
>
void
ExpectBlockedData
(
BlockedMatrix
<
T
>
matrix
,
const
std
::
vector
<
std
::
vector
<
T
>>
&
data
)
{
EXPECT_EQ
(
matrix
.
num_vectors
(),
data
.
size
());
// The indices don't really have semantic names, so we just use `i` and `j`.
// See BlockedMatrixFormat for details.
for
(
int
i
=
0
;
i
<
matrix
.
num_vectors
();
++
i
)
{
EXPECT_EQ
(
matrix
.
block_size
(),
data
[
i
].
size
());
for
(
int
j
=
0
;
j
<
data
[
i
].
size
();
++
j
)
{
EXPECT_EQ
(
matrix
.
vector
(
i
)[
j
],
data
[
i
][
j
]);
}
}
}
// Returns an ArrayVariableStoreSpec parsed from the |text|.
ArrayVariableStoreSpec
MakeSpec
(
const
string
&
text
)
{
ArrayVariableStoreSpec
spec
;
CHECK
(
TextFormat
::
ParseFromString
(
text
,
&
spec
));
return
spec
;
}
// Returns an ArrayVariableStoreSpec that has proper top-level settings and
// whose variables are parsed from the |variables_text|.
ArrayVariableStoreSpec
MakeSpecWithVariables
(
const
string
&
variables_text
)
{
return
MakeSpec
(
tensorflow
::
strings
::
StrCat
(
"version: 0 alignment_bytes: "
,
internal
::
kAlignmentBytes
,
" is_little_endian: "
,
tensorflow
::
port
::
kLittleEndian
,
" "
,
variables_text
));
}
// Tests that kLittleEndian actually means little-endian.
TEST
(
ArrayVariableStoreTest
,
EndianDetection
)
{
static_assert
(
sizeof
(
uint32
)
==
4
*
sizeof
(
uint8
),
"Unexpected int sizes"
);
const
uint32
foo
=
0xdeadbeef
;
uint8
foo_bytes
[
4
];
memcpy
(
foo_bytes
,
&
foo
,
4
*
sizeof
(
uint8
));
if
(
tensorflow
::
port
::
kLittleEndian
)
{
EXPECT_EQ
(
foo_bytes
[
3
],
0xde
);
EXPECT_EQ
(
foo_bytes
[
2
],
0xad
);
EXPECT_EQ
(
foo_bytes
[
1
],
0xbe
);
EXPECT_EQ
(
foo_bytes
[
0
],
0xef
);
}
else
{
EXPECT_EQ
(
foo_bytes
[
0
],
0xde
);
EXPECT_EQ
(
foo_bytes
[
1
],
0xad
);
EXPECT_EQ
(
foo_bytes
[
2
],
0xbe
);
EXPECT_EQ
(
foo_bytes
[
3
],
0xef
);
}
}
// Tests that the store checks for missing fields.
TEST
(
ArrayVariableStoreTest
,
MissingRequiredField
)
{
for
(
const
string
kSpec
:
{
"version: 0 alignment_bytes: 0"
,
"version: 0 is_little_endian: true"
,
"alignment_bytes: 0 is_little_endian: true"
})
{
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpec
(
kSpec
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
"ArrayVariableStoreSpec is missing a required field"
));
}
}
// Tests that the store checks for a matching version number.
TEST
(
ArrayVariableStoreTest
,
VersionMismatch
)
{
const
string
kSpec
=
"version: 999 alignment_bytes: 0 is_little_endian: true"
;
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpec
(
kSpec
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
"ArrayVariableStoreSpec.version (999) "
"does not match the binary (0)"
));
}
// Tests that the store checks for a matching alignment requirement.
TEST
(
ArrayVariableStoreTest
,
AlignmentMismatch
)
{
const
string
kSpec
=
"version: 0 alignment_bytes: 1 is_little_endian: true"
;
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpec
(
kSpec
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
tensorflow
::
strings
::
StrCat
(
"ArrayVariableStoreSpec.alignment_bytes (1) does not match "
"the binary ("
,
internal
::
kAlignmentBytes
,
")"
)));
}
// Tests that the store checks for matching endian-ness.
TEST
(
ArrayVariableStoreTest
,
EndiannessMismatch
)
{
const
string
kSpec
=
tensorflow
::
strings
::
StrCat
(
"version: 0 alignment_bytes: "
,
internal
::
kAlignmentBytes
,
" is_little_endian: "
,
!
tensorflow
::
port
::
kLittleEndian
);
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpec
(
kSpec
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
tensorflow
::
strings
::
StrCat
(
"ArrayVariableStoreSpec.is_little_endian ("
,
!
tensorflow
::
port
::
kLittleEndian
,
") does not match the binary ("
,
tensorflow
::
port
::
kLittleEndian
,
")"
)));
}
// Tests that the store rejects FORMAT_UNKNOWN variables.
TEST
(
ArrayVariableStoreTest
,
RejectFormatUnknown
)
{
const
string
kVariables
=
"variable { format: FORMAT_UNKNOWN }"
;
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
"Unknown variable format"
));
}
// Tests that the store rejects FORMAT_FLAT variables with too few sub-views.
TEST
(
ArrayVariableStoreTest
,
TooFewViewsForFlatVariable
)
{
const
string
kVariables
=
"variable { format: FORMAT_FLAT num_views: 0 }"
;
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
"Flat variables must have 1 view"
));
}
// Tests that the store rejects FORMAT_FLAT variables with too many sub-views.
TEST
(
ArrayVariableStoreTest
,
TooManyViewsForFlatVariable
)
{
const
string
kVariables
=
"variable { format: FORMAT_FLAT num_views: 2 }"
;
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
AlignedView
()),
test
::
IsErrorWithSubstr
(
"Flat variables must have 1 view"
));
}
// Tests that the store accepts FORMAT_ROW_MAJOR_MATRIX variables with one
// sub-view.
TEST
(
ArrayVariableStoreTest
,
MatrixWithOneRow
)
{
const
string
kVariables
=
"variable { format: FORMAT_ROW_MAJOR_MATRIX num_views: 1 view_size: 0 }"
;
ArrayVariableStore
store
;
TF_EXPECT_OK
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
AlignedView
()));
}
// Tests that the store rejects variables that overrun the main byte array.
TEST
(
ArrayVariableStoreTest
,
VariableOverrunsMainByteArray
)
{
const
string
kVariables
=
"variable { format: FORMAT_FLAT num_views: 1 view_size: 1024 }"
;
AlignedView
data
;
TF_ASSERT_OK
(
data
.
Reset
(
nullptr
,
1023
));
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
data
),
test
::
IsErrorWithSubstr
(
"Variable would overrun main byte array"
));
}
// Tests that the store rejects duplicate variables.
TEST
(
ArrayVariableStoreTest
,
DuplicateVariables
)
{
const
string
kVariables
=
R"(
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
)"
;
AlignedView
data
;
TF_ASSERT_OK
(
data
.
Reset
(
nullptr
,
1
<<
20
));
// 1MB
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
data
),
test
::
IsErrorWithSubstr
(
"Duplicate variable"
));
}
// Tests that the store rejects sets of variables that do not completely cover
// the main byte array.
TEST
(
ArrayVariableStoreTest
,
LeftoverBytesInMainByteArray
)
{
const
string
kVariables
=
R"(
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
variable { name: 'z' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
)"
;
AlignedView
data
;
TF_ASSERT_OK
(
data
.
Reset
(
nullptr
,
1
<<
20
));
// 1MB
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
data
),
test
::
IsErrorWithSubstr
(
"Variables do not completely cover main byte array"
));
}
// The fast matrix-vector routines do not support padding.
TEST
(
ArrayVariableStoreTest
,
PaddingInBlockedMatrix
)
{
const
string
kVariables
=
R"(
variable {
name: "baz"
format: FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
num_views: 4
view_size: 16
dimension: 2
dimension: 4
dimension: 2
}
)"
;
AlignedView
data
;
TF_ASSERT_OK
(
data
.
Reset
(
nullptr
,
1
<<
20
));
// 1MB
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
kVariables
),
data
),
test
::
IsErrorWithSubstr
(
"Currently, fast matrix-vector operations do not support "
"padded blocked matrices"
));
}
// Tests that the store cannot retrieve variables when it is uninitialized.
TEST
(
ArrayVariableStoreTest
,
LookupWhenUninitialized
)
{
ArrayVariableStore
store
;
Vector
<
float
>
vector
;
EXPECT_THAT
(
store
.
Lookup
(
"foo"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore not initialized"
));
}
// Tests that the store can use an empty byte array when there are no variables.
TEST
(
ArrayVariableStoreTest
,
EmptyByteArrayWorksIfNoVariables
)
{
ArrayVariableStore
store
;
TF_EXPECT_OK
(
store
.
Reset
(
MakeSpecWithVariables
(
""
),
AlignedView
()));
// The store contains nothing.
Vector
<
float
>
vector
;
EXPECT_THAT
(
store
.
Lookup
(
"foo"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with name "
"'foo' and format FORMAT_FLAT"
));
}
// Tests that the store fails if it is closed before it has been initialized.
TEST
(
ArrayVariableStoreTest
,
CloseBeforeReset
)
{
ArrayVariableStore
store
;
EXPECT_THAT
(
store
.
Close
(),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore not initialized"
));
}
// Tests that the store can be closed (once) after it has been initialized.
TEST
(
ArrayVariableStoreTest
,
CloseAfterReset
)
{
ArrayVariableStore
store
;
TF_ASSERT_OK
(
store
.
Reset
(
MakeSpecWithVariables
(
""
),
AlignedView
()));
TF_EXPECT_OK
(
store
.
Close
());
// Closing twice is still an error.
EXPECT_THAT
(
store
.
Close
(),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore not initialized"
));
}
// Templated on an ArrayVariableStore subclass.
template
<
class
Subclass
>
class
ArrayVariableStoreSubclassTest
:
public
::
testing
::
Test
{};
typedef
::
testing
::
Types
<
FileArrayVariableStore
,
MmapArrayVariableStore
>
Subclasses
;
TYPED_TEST_CASE
(
ArrayVariableStoreSubclassTest
,
Subclasses
);
// Tests that the store fails to load a non-existent file.
TYPED_TEST
(
ArrayVariableStoreSubclassTest
,
NonExistentFile
)
{
// Paths to the spec and data produced by array_variable_store_builder_test.
const
string
kDataPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/non_existent_file"
);
TypeParam
store
;
EXPECT_THAT
(
store
.
Reset
(
MakeSpecWithVariables
(
""
),
kDataPath
),
test
::
IsErrorWithSubstr
(
""
));
}
// Tests that the store can load an empty file if there are no variables.
TYPED_TEST
(
ArrayVariableStoreSubclassTest
,
EmptyFile
)
{
// Paths to the spec and data produced by array_variable_store_builder_test.
const
string
kDataPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/empty_file"
);
TypeParam
store
;
TF_ASSERT_OK
(
store
.
Reset
(
MakeSpecWithVariables
(
""
),
kDataPath
));
Vector
<
float
>
vector
;
Matrix
<
float
>
row_major_matrix
;
EXPECT_THAT
(
store
.
Lookup
(
"foo"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with "
"name 'foo' and format FORMAT_FLAT"
));
EXPECT_THAT
(
store
.
Lookup
(
"bar"
,
&
row_major_matrix
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with name "
"'bar' and format FORMAT_ROW_MAJOR_MATRIX"
));
}
// Tests that the store, when loading a pre-built byte array, produces the same
// variables that the builder converted.
TYPED_TEST
(
ArrayVariableStoreSubclassTest
,
RegressionTest
)
{
// Paths to the spec and data produced by array_variable_store_builder_test.
const
string
kSpecPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/array_variable_store_spec"
);
const
string
kDataPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/array_variable_store_data"
);
ArrayVariableStoreSpec
spec
;
TF_CHECK_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
kSpecPath
,
&
spec
));
TypeParam
store
;
TF_ASSERT_OK
(
store
.
Reset
(
spec
,
kDataPath
));
Matrix
<
float
>
foo
;
TF_ASSERT_OK
(
store
.
Lookup
(
"foo"
,
&
foo
));
// NB: These assertions must be kept in sync with the variables defined in
// array_variable_store_builder_test.cc.
ExpectMatrix
(
foo
,
{{
0.0
,
0.5
,
1.0
},
//
{
1.5
,
2.0
,
2.5
},
//
{
3.0
,
3.5
,
4.0
},
//
{
4.5
,
5.0
,
5.5
}});
// Blocked formats.
BlockedMatrix
<
double
>
baz
;
TF_ASSERT_OK
(
store
.
Lookup
(
"baz"
,
&
baz
));
EXPECT_EQ
(
baz
.
num_rows
(),
2
);
EXPECT_EQ
(
baz
.
num_columns
(),
8
);
EXPECT_EQ
(
baz
.
block_size
(),
4
);
ExpectBlockedData
(
baz
,
{{
1.0
,
2.0
,
2.0
,
2.0
},
//
{
3.0
,
4.0
,
4.0
,
4.0
},
//
{
5.0
,
6.0
,
6.0
,
6.0
},
//
{
7.0
,
8.0
,
8.0
,
8.0
}});
// Try versions of "foo" and "baz" with the wrong format.
Vector
<
float
>
vector
;
Matrix
<
float
>
row_major_matrix
;
EXPECT_THAT
(
store
.
Lookup
(
"foo"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with "
"name 'foo' and format FORMAT_FLAT"
));
EXPECT_THAT
(
store
.
Lookup
(
"baz"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with "
"name 'baz' and format FORMAT_FLAT"
));
EXPECT_THAT
(
store
.
Lookup
(
"baz"
,
&
row_major_matrix
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with name "
"'baz' and format FORMAT_ROW_MAJOR_MATRIX"
));
// Try totally unknown variables.
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
vector
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with "
"name 'missing' and format FORMAT_FLAT"
));
EXPECT_THAT
(
store
.
Lookup
(
"missing"
,
&
row_major_matrix
),
test
::
IsErrorWithSubstr
(
"ArrayVariableStore has no variable with name "
"'missing' and format FORMAT_ROW_MAJOR_MATRIX"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/attributes.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <set>
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
Attributes
::
Reset
(
const
tensorflow
::
protobuf
::
Map
<
string
,
string
>
&
mapping
)
{
// First pass: Parse each value in the |mapping|.
for
(
const
auto
&
name_value
:
mapping
)
{
const
string
&
name
=
name_value
.
first
;
const
string
&
value
=
name_value
.
second
;
const
auto
it
=
attributes_
.
find
(
name
);
if
(
it
==
attributes_
.
end
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Unknown attribute: "
,
name
);
}
TF_RETURN_IF_ERROR
(
it
->
second
->
Parse
(
value
));
}
// Second pass: Look for missing mandatory attributes.
std
::
set
<
string
>
missing_mandatory_attributes
;
for
(
const
auto
&
it
:
attributes_
)
{
const
string
&
name
=
it
.
first
;
Attribute
*
attribute
=
it
.
second
;
if
(
!
attribute
->
IsMandatory
())
continue
;
if
(
mapping
.
find
(
name
)
==
mapping
.
end
())
{
missing_mandatory_attributes
.
insert
(
name
);
}
}
if
(
!
missing_mandatory_attributes
.
empty
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Missing mandatory attributes: "
,
tensorflow
::
str_util
::
Join
(
missing_mandatory_attributes
,
" "
));
}
return
tensorflow
::
Status
::
OK
();
}
void
Attributes
::
Register
(
const
string
&
name
,
Attribute
*
attribute
)
{
const
bool
unique
=
attributes_
.
emplace
(
name
,
attribute
).
second
;
DCHECK
(
unique
)
<<
"Duplicate attribute '"
<<
name
<<
"'"
;
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
string
*
value
)
{
*
value
=
str
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
bool
*
value
)
{
const
string
lowercased_str
=
tensorflow
::
str_util
::
Lowercase
(
str
);
if
(
lowercased_str
!=
"true"
&&
lowercased_str
!=
"false"
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Attribute can't be parsed as bool: "
,
str
);
}
*
value
=
lowercased_str
==
"true"
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
int32
*
value
)
{
if
(
!
tensorflow
::
strings
::
safe_strto32
(
str
,
value
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Attribute can't be parsed as int32: "
,
str
);
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
int64
*
value
)
{
if
(
!
tensorflow
::
strings
::
safe_strto64
(
str
,
value
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Attribute can't be parsed as int64: "
,
str
);
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
size_t
*
value
)
{
int64
signed_value
=
0
;
if
(
!
tensorflow
::
strings
::
safe_strto64
(
str
,
&
signed_value
)
||
signed_value
<
0
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Attribute can't be parsed as size_t: "
,
str
);
}
*
value
=
signed_value
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
float
*
value
)
{
if
(
!
tensorflow
::
strings
::
safe_strtof
(
str
.
c_str
(),
value
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Attribute can't be parsed as float: "
,
str
);
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/attributes.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for parsing configuration attributes from (name,value) string pairs as
// typed values. Intended for parsing RegisteredModuleSpec.parameters, similar
// to get_attrs_with_defaults() in network_units.py. Example usage:
//
// // Create a subclass of Attributes.
// struct MyComponentAttributes : public Attributes {
// // Mandatory attribute with type and name. The "this" allows the attribute
// // to register itself in its container---i.e., MyComponentAttributes.
// Mandatory<float> coefficient{"coefficient", this};
//
// // Optional attributes with type, name, and default value.
// Optional<bool> ignore_case{"ignore_case", true, this};
// Optional<std::vector<int32>> layer_sizes{"layer_sizes", {1, 2, 3}, this};
//
// // Ignored attribute, which does not parse any value.
// Ignored dropout_keep_prob{"dropout_keep_prob", this};
// };
//
// // Initialize an instance of the subclass from a string-to-string mapping.
// RegisteredModuleSpec spec;
// MyComponentAttributes attributes;
// TF_RETURN_IF_ERROR(attributes.Reset(spec.parameters()));
//
// // Access the attributes as accessors.
// bool ignore_case = attributes.ignore_case();
// float coefficient = attributes.coefficient();
// const std::vector<int32> &layer_sizes = attributes.layer_sizes();
//
// See the unit test for additional usage examples.
//
// TODO(googleuser): Build typed attributes into the RegisteredModuleSpec and
// get rid of this module.
#ifndef DRAGNN_RUNTIME_ATTRIBUTES_H_
#define DRAGNN_RUNTIME_ATTRIBUTES_H_
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Base class for sets of attributes. Use as indicated in the file comment.
class
Attributes
{
public:
// Untyped mapping from which typed attributes are parsed.
using
Mapping
=
tensorflow
::
protobuf
::
Map
<
string
,
string
>
;
// Forbids copying, which would invalidate the pointers in |attributes_|.
Attributes
(
const
Attributes
&
that
)
=
delete
;
Attributes
&
operator
=
(
const
Attributes
&
that
)
=
delete
;
// Parses registered attributes from the name-to-value |mapping|. On error,
// returns non-OK. Errors include unknown names in |mapping|, string-to-value
// parsing failures, and missing mandatory attributes.
tensorflow
::
Status
Reset
(
const
Mapping
&
mapping
);
protected:
// Implementations of the supported kinds of attributes, defined below.
class
Ignored
;
template
<
class
T
>
class
Optional
;
template
<
class
T
>
class
Mandatory
;
// Forbids lifecycle management except via subclasses.
Attributes
()
=
default
;
virtual
~
Attributes
()
=
default
;
private:
// Base class for an individual attribute, defined below.
class
Attribute
;
// Registers the |attribute| with the |name|, which must be unique.
void
Register
(
const
string
&
name
,
Attribute
*
attribute
);
// Parses the string |str| into the |value| object.
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
string
*
value
);
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
bool
*
value
);
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
int32
*
value
);
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
int64
*
value
);
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
size_t
*
value
);
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
float
*
value
);
template
<
class
Element
>
static
tensorflow
::
Status
ParseValue
(
const
string
&
str
,
std
::
vector
<
Element
>
*
value
);
// Registered attributes, keyed by name.
std
::
map
<
string
,
Attribute
*>
attributes_
;
};
// Implementation details below.
// Base class for individual attributes.
class
Attributes
::
Attribute
{
public:
Attribute
()
=
default
;
Attribute
(
const
Attribute
&
that
)
=
delete
;
Attribute
&
operator
=
(
const
Attribute
&
that
)
=
delete
;
virtual
~
Attribute
()
=
default
;
// Parses the |value| string into a typed object. On error, returns non-OK.
virtual
tensorflow
::
Status
Parse
(
const
string
&
value
)
=
0
;
// Returns true if this is a mandatory attribute. Defaults to optional.
virtual
bool
IsMandatory
()
const
{
return
false
;
}
};
// Implements an ignored attribute.
class
Attributes
::
Ignored
:
public
Attribute
{
public:
// Registers this in the |attributes| with the |name|.
Ignored
(
const
string
&
name
,
Attributes
*
attributes
)
{
attributes
->
Register
(
name
,
this
);
}
// Ignores the |value|.
tensorflow
::
Status
Parse
(
const
string
&
value
)
override
{
return
tensorflow
::
Status
::
OK
();
}
};
// Implements an optional attribute.
template
<
class
T
>
class
Attributes
::
Optional
:
public
Attribute
{
public:
// Registers this in the |attributes| with the |name| and |default_value|.
Optional
(
const
string
&
name
,
const
T
&
default_value
,
Attributes
*
attributes
)
:
value_
(
default_value
)
{
attributes
->
Register
(
name
,
this
);
}
// Parses the |value| into the |value_|.
tensorflow
::
Status
Parse
(
const
string
&
value
)
override
{
return
ParseValue
(
value
,
&
value_
);
}
// Returns the parsed |value_|. Overloading operator() allows a struct member
// to be called like an accessor.
const
T
&
operator
()()
const
{
return
value_
;
}
private:
// The parsed value, or the default value if not explicitly specified.
T
value_
;
};
// Implements a mandatory attribute.
template
<
class
T
>
class
Attributes
::
Mandatory
:
public
Optional
<
T
>
{
public:
// Registers this in the |attributes| with the |name|.
Mandatory
(
const
string
&
name
,
Attributes
*
attributes
)
:
Optional
<
T
>
(
name
,
T
(),
attributes
)
{}
// Returns true since this is mandatory.
bool
IsMandatory
()
const
override
{
return
true
;
}
private:
// The parsed value, or the default value if not explicitly specified.
T
value_
;
};
template
<
class
Element
>
tensorflow
::
Status
Attributes
::
ParseValue
(
const
string
&
str
,
std
::
vector
<
Element
>
*
value
)
{
value
->
clear
();
if
(
!
str
.
empty
())
{
for
(
const
string
&
element_str
:
tensorflow
::
str_util
::
Split
(
str
,
","
))
{
value
->
emplace_back
();
TF_RETURN_IF_ERROR
(
ParseValue
(
element_str
,
&
value
->
back
()));
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_ATTRIBUTES_H_
research/syntaxnet/dragnn/runtime/attributes_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <map>
#include <set>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns the attribute mapping equivalent of the |std_map|.
Attributes
::
Mapping
MakeMapping
(
const
std
::
map
<
string
,
string
>
&
std_map
)
{
Attributes
::
Mapping
mapping
;
for
(
const
auto
&
it
:
std_map
)
mapping
[
it
.
first
]
=
it
.
second
;
return
mapping
;
}
// Returns a mapping with all attributes explicitly set.
Attributes
::
Mapping
GetFullySpecifiedMapping
()
{
return
MakeMapping
({{
"some_string"
,
"explicit"
},
{
"some_bool"
,
"true"
},
{
"some_int32"
,
"987"
},
{
"some_int64"
,
"654321"
},
{
"some_size_t"
,
"7777777"
},
{
"some_float"
,
"0.25"
},
{
"some_intvec"
,
"2,3,5,7,11,13"
},
{
"some_strvec"
,
"a,bc,def"
}});
}
// A set of optional attributes.
struct
OptionalAttributes
:
public
Attributes
{
Optional
<
string
>
some_string
{
"some_string"
,
"default"
,
this
};
Optional
<
bool
>
some_bool
{
"some_bool"
,
false
,
this
};
Optional
<
int32
>
some_int32
{
"some_int32"
,
32
,
this
};
Optional
<
int64
>
some_int64
{
"some_int64"
,
64
,
this
};
Optional
<
size_t
>
some_size_t
{
"some_size_t"
,
999
,
this
};
Optional
<
float
>
some_float
{
"some_float"
,
-
1.5
,
this
};
Optional
<
std
::
vector
<
int32
>>
some_intvec
{
"some_intvec"
,
{},
this
};
Optional
<
std
::
vector
<
string
>>
some_strvec
{
"some_strvec"
,
{
"x"
,
"y"
},
this
};
};
// Tests that attributes take their default values when they are not explicitly
// specified.
TEST
(
OptionalAttributesTest
,
Defaulted
)
{
Attributes
::
Mapping
mapping
;
OptionalAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
mapping
));
EXPECT_EQ
(
attributes
.
some_string
(),
"default"
);
EXPECT_FALSE
(
attributes
.
some_bool
());
EXPECT_EQ
(
attributes
.
some_int32
(),
32
);
EXPECT_EQ
(
attributes
.
some_int64
(),
64
);
EXPECT_EQ
(
attributes
.
some_size_t
(),
999
);
EXPECT_EQ
(
attributes
.
some_float
(),
-
1.5
);
EXPECT_EQ
(
attributes
.
some_intvec
(),
std
::
vector
<
int32
>
());
EXPECT_EQ
(
attributes
.
some_strvec
(),
std
::
vector
<
string
>
({
"x"
,
"y"
}));
}
// Tests that attributes can be overridden to explicitly-specified values.
TEST
(
OptionalAttributesTest
,
FullySpecified
)
{
OptionalAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
GetFullySpecifiedMapping
()));
EXPECT_EQ
(
attributes
.
some_string
(),
"explicit"
);
EXPECT_TRUE
(
attributes
.
some_bool
());
EXPECT_EQ
(
attributes
.
some_int32
(),
987
);
EXPECT_EQ
(
attributes
.
some_int64
(),
654321
);
EXPECT_EQ
(
attributes
.
some_size_t
(),
7777777
);
EXPECT_EQ
(
attributes
.
some_float
(),
0.25
);
EXPECT_EQ
(
attributes
.
some_intvec
(),
std
::
vector
<
int32
>
({
2
,
3
,
5
,
7
,
11
,
13
}));
EXPECT_EQ
(
attributes
.
some_strvec
(),
std
::
vector
<
string
>
({
"a"
,
"bc"
,
"def"
}));
}
// Tests that attribute parsing fails for an unknown name.
TEST
(
OptionalAttributesTest
,
UnknownName
)
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"unknown"
,
"##BAD##"
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Unknown attribute"
));
}
// Tests that attribute parsing fails for malformed bool values.
TEST
(
OptionalAttributesTest
,
BadBool
)
{
for
(
const
string
&
value
:
{
" true"
,
"true "
,
"tr ue"
,
"arst"
,
"1"
,
"t"
,
"y"
,
"yes"
,
" false"
,
"false "
,
"fa lse"
,
"oien"
,
"0"
,
"f"
,
"n"
,
"no"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_bool"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as bool"
));
}
}
// Tests that attribute parsing works for well-formed bool values.
TEST
(
OptionalAttributesTest
,
GoodBool
)
{
for
(
const
string
&
value
:
{
"true"
,
"TRUE"
,
"True"
,
"tRuE"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_bool"
,
value
}});
OptionalAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
mapping
));
EXPECT_TRUE
(
attributes
.
some_bool
());
}
for
(
const
string
&
value
:
{
"false"
,
"FALSE"
,
"False"
,
"fAlSe"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_bool"
,
value
}});
OptionalAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
mapping
));
EXPECT_FALSE
(
attributes
.
some_bool
());
}
}
// Tests that attribute parsing fails for malformed int32 values.
TEST
(
OptionalAttributesTest
,
BadInt32
)
{
for
(
const
string
&
value
:
{
"hello"
,
"true"
,
"1.0"
,
"inf"
,
"nan"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_int32"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as int32"
));
}
}
// Tests that attribute parsing fails for malformed int64 values.
TEST
(
OptionalAttributesTest
,
BadInt64
)
{
for
(
const
string
&
value
:
{
"hello"
,
"true"
,
"1.0"
,
"inf"
,
"nan"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_int64"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as int64"
));
}
}
// Tests that attribute parsing fails for malformed size_t values.
TEST
(
OptionalAttributesTest
,
BadSizeT
)
{
for
(
const
string
&
value
:
{
"hello"
,
"true"
,
"1.0"
,
"inf"
,
"nan"
,
"-1.0"
,
"-123"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_size_t"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as size_t"
));
}
}
// Tests that attribute parsing fails for malformed floats.
TEST
(
OptionalAttributesTest
,
BadFloat
)
{
for
(
const
string
&
value
:
{
"hello"
,
"true"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_float"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as float"
));
}
}
// Tests that attribute parsing fails for malformed std::vector<int32> values.
TEST
(
OptionalAttributesTest
,
BadIntVector
)
{
for
(
const
string
&
value
:
{
"hello"
,
"true"
,
"1.0"
,
"inf"
,
"nan"
,
"true,false"
,
"foo,bar,baz"
})
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"some_intvec"
,
value
}});
OptionalAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Attribute can't be parsed as int32"
));
}
}
// A set of mandatory attributes.
struct
MandatoryAttributes
:
public
Attributes
{
Mandatory
<
string
>
some_string
{
"some_string"
,
this
};
Mandatory
<
bool
>
some_bool
{
"some_bool"
,
this
};
Mandatory
<
int32
>
some_int32
{
"some_int32"
,
this
};
Mandatory
<
int64
>
some_int64
{
"some_int64"
,
this
};
Mandatory
<
size_t
>
some_size_t
{
"some_size_t"
,
this
};
Mandatory
<
float
>
some_float
{
"some_float"
,
this
};
Mandatory
<
std
::
vector
<
int32
>>
some_intvec
{
"some_intvec"
,
this
};
Mandatory
<
std
::
vector
<
string
>>
some_strvec
{
"some_strvec"
,
this
};
};
// Tests that attribute parsing works when all mandatory attributes are
// explicitly specified.
TEST
(
MandatoryAttributesTest
,
FullySpecified
)
{
MandatoryAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
GetFullySpecifiedMapping
()));
EXPECT_EQ
(
attributes
.
some_string
(),
"explicit"
);
EXPECT_TRUE
(
attributes
.
some_bool
());
EXPECT_EQ
(
attributes
.
some_int32
(),
987
);
EXPECT_EQ
(
attributes
.
some_int64
(),
654321
);
EXPECT_EQ
(
attributes
.
some_size_t
(),
7777777
);
EXPECT_EQ
(
attributes
.
some_float
(),
0.25
);
EXPECT_EQ
(
attributes
.
some_intvec
(),
std
::
vector
<
int32
>
({
2
,
3
,
5
,
7
,
11
,
13
}));
EXPECT_EQ
(
attributes
.
some_strvec
(),
std
::
vector
<
string
>
({
"a"
,
"bc"
,
"def"
}));
}
// Tests that attribute parsing fails when even one mandatory attribute is not
// explicitly specified.
TEST
(
MandatoryAttributesTest
,
MissingAttribute
)
{
for
(
const
auto
&
it
:
GetFullySpecifiedMapping
())
{
const
string
&
name
=
it
.
first
;
Attributes
::
Mapping
mapping
=
GetFullySpecifiedMapping
();
CHECK_EQ
(
mapping
.
erase
(
name
),
1
);
MandatoryAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Missing mandatory attributes"
));
}
}
// A set of ignored attributes.
struct
IgnoredAttributes
:
public
Attributes
{
Ignored
foo
{
"foo"
,
this
};
Ignored
bar
{
"bar"
,
this
};
Ignored
baz
{
"baz"
,
this
};
};
// Tests that ignored attributes are not mandatory.
TEST
(
IgnoredAttributesTest
,
NotMandatory
)
{
const
Attributes
::
Mapping
mapping
;
IgnoredAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
mapping
));
}
// Tests that attribute parsing consumes ignored names.
TEST
(
IgnoredAttributesTest
,
IgnoredName
)
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
({{
"foo"
,
"blah"
},
{
"bar"
,
"123"
},
{
"baz"
,
" "
}});
IgnoredAttributes
attributes
;
TF_ASSERT_OK
(
attributes
.
Reset
(
mapping
));
}
// Tests that attribute parsing still fails for unknown names.
TEST
(
IgnoredAttributesTest
,
UnknownName
)
{
const
Attributes
::
Mapping
mapping
=
MakeMapping
(
{{
"foo"
,
"blah"
},
{
"bar"
,
"123"
},
{
"baz"
,
" "
},
{
"unknown"
,
""
}});
IgnoredAttributes
attributes
;
EXPECT_THAT
(
attributes
.
Reset
(
mapping
),
test
::
IsErrorWithSubstr
(
"Unknown attribute"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/biaffine_digraph_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/eigen.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Produces pairwise activations via a biaffine product between source and
// target token activations, as in the Dozat parser. This is the runtime
// version of the BiaffineDigraphNetwork, but is implemented as a Component
// instead of a NetworkUnit so it can control operand allocation.
class
BiaffineDigraphComponent
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
private:
// Weights for computing source-target arc potentials.
Matrix
<
float
>
arc_weights_
;
// Weights for computing source-selection potentials.
Vector
<
float
>
source_weights_
;
// Weights and bias for root-target arc potentials.
Vector
<
float
>
root_weights_
;
float
root_bias_
=
0.0
;
// Source and target token activation inputs.
LayerHandle
<
float
>
sources_handle_
;
LayerHandle
<
float
>
targets_handle_
;
// Directed adjacency matrix output.
PairwiseLayerHandle
<
float
>
adjacency_handle_
;
// Handles for intermediate computations.
LocalMatrixHandle
<
float
>
target_product_handle_
;
};
bool
BiaffineDigraphComponent
::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
const
string
network_unit
=
NetworkUnit
::
GetClassName
(
component_spec
);
return
(
normalized_builder_name
==
"BulkFeatureExtractorComponent"
||
normalized_builder_name
==
"BiaffineDigraphComponent"
)
&&
network_unit
==
"BiaffineDigraphNetwork"
;
}
// Finds the link named |name| in the |component_spec| and points the |handle|
// at the corresponding layer in the |network_state_manager|. The layer must
// also match the |required_dimension|. Returns non-OK on error.
tensorflow
::
Status
FindAndValidateLink
(
const
ComponentSpec
&
component_spec
,
const
NetworkStateManager
&
network_state_manager
,
const
string
&
name
,
size_t
required_dimension
,
LayerHandle
<
float
>
*
handle
)
{
const
LinkedFeatureChannel
*
link
=
nullptr
;
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
name
()
==
name
)
{
link
=
&
channel
;
break
;
}
}
if
(
link
==
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": link '"
,
name
,
"' does not exist"
);
}
const
string
error_suffix
=
tensorflow
::
strings
::
StrCat
(
" in link { "
,
link
->
ShortDebugString
(),
" }"
);
if
(
link
->
embedding_dim
()
!=
-
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": transformed links are forbidden"
,
error_suffix
);
}
if
(
link
->
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": multi-embedding links are forbidden"
,
error_suffix
);
}
if
(
link
->
source_component
()
==
component_spec
.
name
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": recurrent links are forbidden"
,
error_suffix
);
}
if
(
link
->
fml
()
!=
"input.focus"
||
link
->
source_translator
()
!=
"identity"
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": non-trivial link translation is forbidden"
,
error_suffix
);
}
size_t
dimension
=
0
;
TF_RETURN_IF_ERROR
(
network_state_manager
.
LookupLayer
(
link
->
source_component
(),
link
->
source_layer
(),
&
dimension
,
handle
));
if
(
dimension
!=
required_dimension
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": link '"
,
name
,
"' has dimension "
,
dimension
,
" instead of "
,
required_dimension
,
error_suffix
);
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
BiaffineDigraphComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
TF_RETURN_IF_ERROR
(
variable_store
->
Lookup
(
tensorflow
::
strings
::
StrCat
(
component_spec
.
name
(),
"/weights_arc"
),
&
arc_weights_
));
const
size_t
source_dimension
=
arc_weights_
.
num_rows
();
const
size_t
target_dimension
=
arc_weights_
.
num_columns
();
TF_RETURN_IF_ERROR
(
variable_store
->
Lookup
(
tensorflow
::
strings
::
StrCat
(
component_spec
.
name
(),
"/weights_source"
),
&
source_weights_
));
if
(
source_weights_
.
size
()
!=
source_dimension
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": dimension mismatch between weights_arc ["
,
source_dimension
,
","
,
target_dimension
,
"] and weights_source ["
,
source_weights_
.
size
(),
"]"
);
}
TF_RETURN_IF_ERROR
(
variable_store
->
Lookup
(
tensorflow
::
strings
::
StrCat
(
component_spec
.
name
(),
"/root_weights"
),
&
root_weights_
));
if
(
root_weights_
.
size
()
!=
target_dimension
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": dimension mismatch between weights_arc ["
,
source_dimension
,
","
,
target_dimension
,
"] and root_weights ["
,
root_weights_
.
size
(),
"]"
);
}
Vector
<
float
>
root_bias
;
TF_RETURN_IF_ERROR
(
variable_store
->
Lookup
(
tensorflow
::
strings
::
StrCat
(
component_spec
.
name
(),
"/root_bias"
),
&
root_bias
));
if
(
root_bias
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": root_bias must be a singleton"
);
}
root_bias_
=
root_bias
[
0
];
if
(
component_spec
.
fixed_feature_size
()
!=
0
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": fixed features are forbidden"
);
}
if
(
component_spec
.
linked_feature_size
()
!=
2
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
component_spec
.
name
(),
": two linked features are required"
);
}
TF_RETURN_IF_ERROR
(
FindAndValidateLink
(
component_spec
,
*
network_state_manager
,
"sources"
,
source_dimension
,
&
sources_handle_
));
TF_RETURN_IF_ERROR
(
FindAndValidateLink
(
component_spec
,
*
network_state_manager
,
"targets"
,
target_dimension
,
&
targets_handle_
));
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLayer
(
"adjacency"
,
1
,
&
adjacency_handle_
));
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLocal
(
source_dimension
,
&
target_product_handle_
));
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
BiaffineDigraphComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
NetworkStates
&
network_states
=
session_state
->
network_states
;
// Infer the number of steps from the source and target activations.
EigenMatrixMap
<
float
>
sources
=
AsEigenMap
(
Matrix
<
float
>
(
network_states
.
GetLayer
(
sources_handle_
)));
EigenMatrixMap
<
float
>
targets
=
AsEigenMap
(
Matrix
<
float
>
(
network_states
.
GetLayer
(
targets_handle_
)));
const
size_t
num_steps
=
sources
.
rows
();
if
(
targets
.
rows
()
!=
num_steps
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"step count mismatch between sources ("
,
num_steps
,
") and targets ("
,
targets
.
rows
(),
")"
);
}
// Since this component has a pairwise layer, allocate steps in one shot.
network_states
.
AddSteps
(
num_steps
);
MutableEigenMatrixMap
<
float
>
adjacency
=
AsEigenMap
(
network_states
.
GetLayer
(
adjacency_handle_
));
MutableEigenMatrixMap
<
float
>
target_product
=
AsEigenMap
(
network_states
.
GetLocal
(
target_product_handle_
));
// First compute the adjacency matrix of combined arc and source scores.
// Note: .noalias() ensures that the RHS is assigned directly to the LHS;
// otherwise, Eigen may allocate a temp matrix to hold the result of the
// matmul on the RHS and then copy that to the LHS. See
// http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
target_product
.
noalias
()
=
targets
*
AsEigenMap
(
arc_weights_
).
transpose
();
target_product
.
rowwise
()
+=
AsEigenMap
(
source_weights_
);
adjacency
.
noalias
()
=
target_product
*
sources
.
transpose
();
// Now overwrite the diagonal with root-selection scores.
// Note: .array() allows the scalar addition of |root_bias_| to broadcast
// across the diagonal. See
// https://eigen.tuxfamily.org/dox/group__TutorialArrayClass.html
adjacency
.
diagonal
().
noalias
()
=
AsEigenMap
(
root_weights_
)
*
targets
.
transpose
();
adjacency
.
diagonal
().
array
()
+=
root_bias_
;
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
BiaffineDigraphComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/biaffine_digraph_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
size_t
kNumSteps
=
33
;
constexpr
size_t
kSourceDim
=
44
;
constexpr
size_t
kTargetDim
=
55
;
constexpr
size_t
kBadDim
=
11
;
constexpr
float
kArcWeight
=
1.0
;
constexpr
float
kSourceWeight
=
2.0
;
constexpr
float
kRootWeight
=
4.0
;
constexpr
float
kRootBias
=
8.0
;
constexpr
float
kSourceValue
=
-
0.5
;
constexpr
float
kTargetValue
=
1.5
;
constexpr
char
kSourcesComponentName
[]
=
"sources"
;
constexpr
char
kTargetsComponentName
[]
=
"targets"
;
constexpr
char
kSourcesLayerName
[]
=
"sources"
;
constexpr
char
kTargetsLayerName
[]
=
"targets"
;
constexpr
char
kBadDimLayerName
[]
=
"bad"
;
// Configuration for the Run() method. This makes it easier for tests to
// manipulate breakages.
struct
RunConfig
{
// Number of steps in the preceding components.
size_t
sources_num_steps
=
kNumSteps
;
size_t
targets_num_steps
=
kNumSteps
;
// Dimensions of the variables.
size_t
weights_source_dim
=
kSourceDim
;
size_t
root_weights_dim
=
kTargetDim
;
size_t
root_bias_dim
=
1
;
};
class
BiaffineDigraphComponentTest
:
public
NetworkTestBase
{
protected:
BiaffineDigraphComponentTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
}
// Returns a working spec.
static
ComponentSpec
MakeGoodSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
kTestComponentName
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bulk_component.BulkFeatureExtractorComponentBuilder"
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"biaffine_units.BiaffineDigraphNetwork"
);
for
(
const
string
&
name
:
{
kSourcesLayerName
,
kTargetsLayerName
})
{
LinkedFeatureChannel
*
link
=
component_spec
.
add_linked_feature
();
link
->
set_name
(
name
);
link
->
set_embedding_dim
(
-
1
);
link
->
set_size
(
1
);
link
->
set_source_component
(
name
);
link
->
set_source_layer
(
name
);
link
->
set_source_translator
(
"identity"
);
link
->
set_fml
(
"input.focus"
);
}
return
component_spec
;
}
// Creates a component, initializes it based on the |component_spec|, and
// evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
,
const
RunConfig
&
config
=
RunConfig
())
{
AddComponent
(
kSourcesComponentName
);
AddLayer
(
kSourcesLayerName
,
kSourceDim
);
AddComponent
(
kTargetsComponentName
);
AddLayer
(
kTargetsLayerName
,
kTargetDim
);
AddLayer
(
kBadDimLayerName
,
kBadDim
);
AddComponent
(
kTestComponentName
);
AddMatrixVariable
(
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/weights_arc"
),
kSourceDim
,
kTargetDim
,
kArcWeight
);
AddVectorVariable
(
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/weights_source"
),
config
.
weights_source_dim
,
kSourceWeight
);
AddVectorVariable
(
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/root_weights"
),
config
.
root_weights_dim
,
kRootWeight
);
AddVectorVariable
(
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/root_bias"
),
config
.
root_bias_dim
,
kRootBias
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"BiaffineDigraphComponent"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
config
.
sources_num_steps
);
FillLayer
(
kSourcesComponentName
,
kSourcesLayerName
,
kSourceValue
);
StartComponent
(
config
.
targets_num_steps
);
FillLayer
(
kTargetsComponentName
,
kTargetsLayerName
,
kTargetValue
);
StartComponent
(
0
);
// BiaffineDigraphComponent will add steps
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
adjacency_
=
GetPairwiseLayer
(
kTestComponentName
,
"adjacency"
);
return
tensorflow
::
Status
::
OK
();
}
InputBatchCache
input_
;
std
::
unique_ptr
<
Component
>
component_
;
Matrix
<
float
>
adjacency_
;
};
// Tests that the good spec works properly.
TEST_F
(
BiaffineDigraphComponentTest
,
GoodSpec
)
{
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
()));
constexpr
float
kExpectedRootScore
=
kRootWeight
*
kTargetValue
*
kTargetDim
+
kRootBias
;
constexpr
float
kExpectedArcScore
=
kSourceDim
*
kSourceValue
*
kArcWeight
*
kTargetValue
*
kTargetDim
+
kSourceWeight
*
kSourceValue
*
kSourceDim
;
ASSERT_EQ
(
adjacency_
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
adjacency_
.
num_columns
(),
kNumSteps
);
for
(
size_t
row
=
0
;
row
<
kNumSteps
;
++
row
)
{
for
(
size_t
column
=
0
;
column
<
kNumSteps
;
++
column
)
{
if
(
row
==
column
)
{
ASSERT_EQ
(
adjacency_
.
row
(
row
)[
column
],
kExpectedRootScore
);
}
else
{
ASSERT_EQ
(
adjacency_
.
row
(
row
)[
column
],
kExpectedArcScore
);
}
}
}
}
// Tests the set of supported components.
TEST_F
(
BiaffineDigraphComponentTest
,
Supports
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
string
component_name
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_name
));
EXPECT_EQ
(
component_name
,
"BiaffineDigraphComponent"
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec"
));
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"BiaffineDigraphComponent"
);
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_name
));
EXPECT_EQ
(
component_name
,
"BiaffineDigraphComponent"
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec"
));
}
// Tests that fixed features are rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
FixedFeatures
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
add_fixed_feature
();
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"fixed features are forbidden"
));
}
// Tests that too few linked features are rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
TooFewLinkedFeatures
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
()
->
RemoveLast
();
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"two linked features are required"
));
}
// Tests that too many linked features are rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
TooManyLinkedFeatures
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
add_linked_feature
();
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"two linked features are required"
));
}
// Tests that a spec with no "sources" link is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
MissingSources
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_name
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"link 'sources' does not exist"
));
}
// Tests that a spec with no "targets" link is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
MissingTargets
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_name
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"link 'targets' does not exist"
));
}
// Tests that a spec with transformed links is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
TransformedLink
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_embedding_dim
(
123
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"transformed links are forbidden"
));
}
// Tests that a spec with multi-embedding links is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
MultiEmbeddingLink
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_size
(
2
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"multi-embedding links are forbidden"
));
}
// Tests that a spec with recurrent links is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
RecurrentLink
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_source_component
(
kTestComponentName
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"recurrent links are forbidden"
));
}
// Tests that a spec with improper FML is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
BadFML
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_fml
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"non-trivial link translation is forbidden"
));
}
// Tests that a spec with non-identity links is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
NonIdentityLink
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_source_translator
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"non-trivial link translation is forbidden"
));
}
// Tests that a link with the wrong dimension is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
WrongLinkDimension
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_linked_feature
(
1
)
->
set_source_layer
(
kBadDimLayerName
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"link 'targets' has dimension 11 instead of 55"
));
}
// Tests that a mismatched weights_source dimension is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
WeightsSourceDimensionMismatch
)
{
RunConfig
config
;
config
.
weights_source_dim
=
999
;
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
config
),
test
::
IsErrorWithSubstr
(
"dimension mismatch between weights_arc "
"[44,55] and weights_source [999]"
));
}
// Tests that a mismatched root_weights dimension is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
RootWeightsDimensionMismatch
)
{
RunConfig
config
;
config
.
root_weights_dim
=
999
;
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
config
),
test
::
IsErrorWithSubstr
(
"dimension mismatch between weights_arc "
"[44,55] and root_weights [999]"
));
}
// Tests that a mismatched root_bias dimension is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
RootBiasDimensionMismatch
)
{
RunConfig
config
;
config
.
root_bias_dim
=
999
;
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
config
),
test
::
IsErrorWithSubstr
(
"root_bias must be a singleton"
));
}
// Tests that a mismatched number of steps is rejected.
TEST_F
(
BiaffineDigraphComponentTest
,
StepCountMismatch
)
{
RunConfig
config
;
config
.
targets_num_steps
=
999
;
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
config
),
test
::
IsErrorWithSubstr
(
"step count mismatch between sources (33) and targets (999)"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_dynamic_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit_base.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Network unit that allows us to make calls to NetworkUnitBase and extract
// features. We may want to provide more optimized versions of this class.
class
BulkFeatureExtractorNetwork
:
public
NetworkUnitBase
{
public:
// Returns true if this supports the |component_spec|. Requires:
// * A deterministic transition system, which can be advanced from the oracle.
// * No recurrent linked features (i.e. from this system).
static
bool
Supports
(
const
ComponentSpec
&
component_spec
);
// Implements NetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
// Advances the |compute_session| through all oracle transitions and extracts
// fixed and linked embeddings, concatenates them into an input matrix stored
// in the NetworkStates in the |session_state|, and points the |inputs| at it.
// Also adds steps to the NetworkStates. On error, returns non-OK.
tensorflow
::
Status
EvaluateInputs
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
Matrix
<
float
>
*
inputs
)
const
;
private:
// Implements NetworkUnit. Evaluate() is "final" to encourage inlining.
string
GetLogitsName
()
const
override
{
return
""
;
}
tensorflow
::
Status
Evaluate
(
size_t
step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
final
;
// Name of the containing component.
string
name_
;
// Concatenated input matrix.
LocalMatrixHandle
<
float
>
inputs_handle_
;
};
bool
BulkFeatureExtractorNetwork
::
Supports
(
const
ComponentSpec
&
component_spec
)
{
if
(
!
TransitionSystemTraits
(
component_spec
).
is_deterministic
)
return
false
;
// Forbid recurrent linked features.
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
return
false
;
}
return
true
;
}
tensorflow
::
Status
BulkFeatureExtractorNetwork
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
name_
=
component_spec
.
name
();
if
(
!
Supports
(
component_spec
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"BulkFeatureExtractorNetwork does not support component '"
,
name_
,
"'"
);
}
const
bool
use_concatenated_input
=
true
;
TF_RETURN_IF_ERROR
(
InitializeBase
(
use_concatenated_input
,
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
return
network_state_manager
->
AddLocal
(
concatenated_input_dim
(),
&
inputs_handle_
);
}
tensorflow
::
Status
BulkFeatureExtractorNetwork
::
EvaluateInputs
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
Matrix
<
float
>
*
inputs
)
const
{
// TODO(googleuser): Try the ComputeSession's bulk feature extraction API?
for
(
size_t
step_idx
=
0
;
!
compute_session
->
IsTerminal
(
name_
);
++
step_idx
)
{
session_state
->
network_states
.
AddStep
();
TF_RETURN_IF_ERROR
(
Evaluate
(
step_idx
,
session_state
,
compute_session
));
compute_session
->
AdvanceFromOracle
(
name_
);
}
*
inputs
=
session_state
->
network_states
.
GetLocal
(
inputs_handle_
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
BulkFeatureExtractorNetwork
::
Evaluate
(
size_t
step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
{
Vector
<
float
>
input
;
TF_RETURN_IF_ERROR
(
EvaluateBase
(
session_state
,
compute_session
,
&
input
));
MutableMatrix
<
float
>
all_inputs
=
session_state
->
network_states
.
GetLocal
(
inputs_handle_
);
// TODO(googleuser): Punch a hole in EvaluateBase so it writes directly to
// all_inputs.row(step_index).
//
// In the future, we could entirely eliminate copying, by providing a variant
// of LstmCellFunction::RunInputComputation that adds a partial vector of
// inputs, e.g. instead of RunInputComputation(x), we compute
//
// RunInputComputation(x[0:32]) + RunInputComputation(x[32:64])
//
// where perhaps x[0:32] points directly at a fixed word feature vector, and
// x[32:64] points directly at the previous layer's outputs (as a linked
// feature).
MutableVector
<
float
>
output
=
all_inputs
.
row
(
step_index
);
DCHECK_EQ
(
input
.
size
(),
output
.
size
());
// TODO(googleuser): Try memcpy() or a custom vectorized copy.
for
(
int
i
=
0
;
i
<
input
.
size
();
++
i
)
{
output
[
i
]
=
input
[
i
];
}
return
tensorflow
::
Status
::
OK
();
}
// Bulk version of a DynamicComponent---i.e., a component that was originally
// dynamic but can be automatically upgraded to a bulk version.
class
BulkDynamicComponent
:
public
Component
{
protected:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
true
;
}
private:
// Feature extractor that builds the input activation matrix.
BulkFeatureExtractorNetwork
bulk_feature_extractor_
;
// Network unit for bulk computation.
std
::
unique_ptr
<
BulkNetworkUnit
>
bulk_network_unit_
;
};
// In addition to the BulkFeatureExtractorNetwork requirements, the bulk LSTM
// requires no attention (the runtime doesn't support attention yet).
bool
BulkDynamicComponent
::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
return
BulkFeatureExtractorNetwork
::
Supports
(
component_spec
)
&&
(
normalized_builder_name
==
"DynamicComponent"
||
normalized_builder_name
==
"BulkDynamicComponent"
)
&&
component_spec
.
attention_component
().
empty
();
}
tensorflow
::
Status
BulkDynamicComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
TF_RETURN_IF_ERROR
(
bulk_feature_extractor_
.
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
TF_RETURN_IF_ERROR
(
BulkNetworkUnit
::
CreateOrError
(
BulkNetworkUnit
::
GetClassName
(
component_spec
),
&
bulk_network_unit_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
return
bulk_network_unit_
->
ValidateInputDimension
(
bulk_feature_extractor_
.
concatenated_input_dim
());
}
tensorflow
::
Status
BulkDynamicComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
Matrix
<
float
>
inputs
;
TF_RETURN_IF_ERROR
(
bulk_feature_extractor_
.
EvaluateInputs
(
session_state
,
compute_session
,
&
inputs
));
return
bulk_network_unit_
->
Evaluate
(
inputs
,
session_state
);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
BulkDynamicComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_dynamic_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
_
;
using
::
testing
::
Invoke
;
using
::
testing
::
Return
;
constexpr
size_t
kNumSteps
=
50
;
constexpr
size_t
kFixedDim
=
11
;
constexpr
size_t
kFixedVocabularySize
=
123
;
constexpr
float
kFixedValue
=
0.5
;
constexpr
size_t
kLinkedDim
=
13
;
constexpr
float
kLinkedValue
=
1.25
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
char
kOutputsName
[]
=
"outputs"
;
constexpr
size_t
kOutputsDim
=
kFixedDim
+
kLinkedDim
;
// Adds one to all inputs.
class
BulkAddOne
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
network_state_manager
->
AddLayer
(
kOutputsName
,
kOutputsDim
,
&
outputs_handle_
);
}
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
""
;
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
{
const
MutableMatrix
<
float
>
outputs
=
session_state
->
network_states
.
GetLayer
(
outputs_handle_
);
if
(
outputs
.
num_rows
()
!=
inputs
.
num_rows
()
||
outputs
.
num_columns
()
!=
inputs
.
num_columns
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Dimension mismatch"
);
}
for
(
size_t
row
=
0
;
row
<
inputs
.
num_rows
();
++
row
)
{
for
(
size_t
column
=
0
;
column
<
inputs
.
num_columns
();
++
column
)
{
outputs
.
row
(
row
)[
column
]
=
inputs
.
row
(
row
)[
column
]
+
1.0
;
}
}
return
tensorflow
::
Status
::
OK
();
}
private:
// Output outputs.
LayerHandle
<
float
>
outputs_handle_
;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkAddOne
);
// A component that also prefers itself but is triggered on a certain backend.
// This can be used to cause a component selection conflict.
class
ImTheBest
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
component_spec
.
backend
().
registered_name
()
==
"CauseConflict"
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
true
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheBest
);
class
BulkDynamicComponentTest
:
public
NetworkTestBase
{
protected:
// Returns a spec that the network supports.
ComponentSpec
GetSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
kTestComponentName
);
component_spec
.
set_num_actions
(
1
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"AddOne"
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"DynamicComponent"
);
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
kFixedDim
);
fixed_feature
->
set_vocabulary_size
(
kFixedVocabularySize
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
return
component_spec
;
}
// Adds mock call expectations to the |compute_session_| for the transition
// system traversal and feature extraction.
void
AddComputeSessionMocks
()
{
SetupTransitionLoop
(
kNumSteps
);
EXPECT_CALL
(
compute_session_
,
AdvanceFromOracle
(
_
)).
Times
(
kNumSteps
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
Times
(
kNumSteps
)
.
WillRepeatedly
(
Invoke
(
ExtractFeatures
(
0
,
{{
0
,
1.0
}})));
EXPECT_CALL
(
compute_session_
,
GetTranslatedLinkFeatures
(
_
,
_
))
.
Times
(
kNumSteps
)
.
WillRepeatedly
(
Invoke
(
ExtractLinks
(
0
,
{
"step_idx: 0"
})));
EXPECT_CALL
(
compute_session_
,
SourceComponentBeamSize
(
_
,
_
))
.
Times
(
kNumSteps
)
.
WillRepeatedly
(
Return
(
1
));
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
)
{
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kFixedVocabularySize
,
kFixedDim
,
kFixedValue
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"BulkDynamicComponent"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
// Allocates network states for a few steps.
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kLinkedValue
);
StartComponent
(
0
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
outputs_
=
GetLayer
(
kTestComponentName
,
kOutputsName
);
return
tensorflow
::
Status
::
OK
();
}
std
::
unique_ptr
<
Component
>
component_
;
Matrix
<
float
>
outputs_
;
};
// Tests that the supported spec is supported.
TEST_F
(
BulkDynamicComponentTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
GetSupportedSpec
();
string
component_type
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_type
));
EXPECT_EQ
(
component_type
,
"BulkDynamicComponent"
);
AddComputeSessionMocks
();
TF_ASSERT_OK
(
Run
(
component_spec
));
ASSERT_EQ
(
outputs_
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
outputs_
.
num_columns
(),
kFixedDim
+
kLinkedDim
);
for
(
size_t
row
=
0
;
row
<
kNumSteps
;
++
row
)
{
size_t
column
=
0
;
for
(;
column
<
kFixedDim
;
++
column
)
{
EXPECT_EQ
(
outputs_
.
row
(
row
)[
column
],
kFixedValue
+
1.0
);
}
for
(;
column
<
kFixedDim
+
kLinkedDim
;
++
column
)
{
EXPECT_EQ
(
outputs_
.
row
(
row
)[
column
],
kLinkedValue
+
1.0
);
}
}
}
// Tests that the BulkDynamicComponent also supports its own name.
TEST_F
(
BulkDynamicComponentTest
,
SupportsBulkName
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"BulkDynamicComponent"
);
string
component_type
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_type
));
EXPECT_EQ
(
component_type
,
"BulkDynamicComponent"
);
}
// Tests that the transition system must be deterministic.
TEST_F
(
BulkDynamicComponentTest
,
ForbidNonDeterminism
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
set_num_actions
(
100
);
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"BulkFeatureExtractorNetwork does not support component"
));
}
// Tests that links cannot be recurrent.
TEST_F
(
BulkDynamicComponentTest
,
ForbidRecurrences
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
kTestComponentName
);
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"BulkFeatureExtractorNetwork does not support component"
));
}
// Tests that the component prefers itself.
TEST_F
(
BulkDynamicComponentTest
,
PrefersItself
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"CauseConflict"
);
// The "CauseConflict" backend triggers the ImTheBest component, which also
// prefers itself and leads to a selection conflict.
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"both think they should be preferred"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_feed_forward_network.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/feed_forward_network_kernel.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// A network unit that evaluates a feed-forward multi-layer perceptron.
class
BulkFeedForwardNetwork
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
;
string
GetLogitsName
()
const
override
{
return
kernel_
.
logits_name
();
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
;
private:
// Kernel that implements the feed-forward network.
FeedForwardNetworkKernel
kernel_
;
};
tensorflow
::
Status
BulkFeedForwardNetwork
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"BulkFeedForwardNetwork forbids recurrent links"
);
}
}
return
kernel_
.
Initialize
(
component_spec
,
variable_store
,
network_state_manager
);
}
tensorflow
::
Status
BulkFeedForwardNetwork
::
ValidateInputDimension
(
size_t
dimension
)
const
{
return
kernel_
.
ValidateInputDimension
(
dimension
);
}
tensorflow
::
Status
BulkFeedForwardNetwork
::
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
{
for
(
const
FeedForwardNetworkLayer
&
layer
:
kernel_
.
layers
())
{
inputs
=
layer
.
Apply
(
inputs
,
session_state
->
network_states
);
}
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkFeedForwardNetwork
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_feed_forward_network_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
constexpr
size_t
kInputDim
=
5
;
constexpr
size_t
kLogitsDim
=
3
;
constexpr
size_t
kNumSteps
=
4
;
constexpr
float
kEmbedding
=
1.25
;
// Applies the ReLU activation to the |value|.
float
Relu
(
float
value
)
{
return
std
::
max
(
0.0
f
,
value
);
}
class
BulkFeedForwardNetworkTest
:
public
NetworkTestBase
{
protected:
// Adds a weight matrix with the |name_suffix| with the given dimensions and
// |fill_value|.
void
AddWeights
(
const
string
&
name_suffix
,
size_t
num_rows
,
size_t
num_columns
,
float
fill_value
)
{
const
string
weights_name
=
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/weights_"
,
name_suffix
,
FlexibleMatrixKernel
::
kSuffix
);
AddMatrixVariable
(
weights_name
,
num_columns
,
num_rows
,
fill_value
);
}
// Adds a bias vector with the |name_suffix| with the given dimensions and
// |fill_value|.
void
AddBiases
(
const
string
&
name_suffix
,
size_t
dimension
,
float
fill_value
)
{
const
string
biases_name
=
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/bias_"
,
name_suffix
);
AddVectorVariable
(
biases_name
,
dimension
,
fill_value
);
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
string
&
component_spec_text
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponent
(
kTestComponentName
);
TF_CHECK_OK
(
BulkNetworkUnit
::
CreateOrError
(
"BulkFeedForwardNetwork"
,
&
bulk_network_unit_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
size_t
input_dimension
=
0
;
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
input_dimension
+=
channel
.
embedding_dim
();
}
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
ValidateInputDimension
(
input_dimension
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
const
std
::
vector
<
float
>
row
(
kInputDim
,
kEmbedding
);
UniqueMatrix
<
float
>
input
(
std
::
vector
<
std
::
vector
<
float
>>
(
kNumSteps
,
row
));
return
bulk_network_unit_
->
Evaluate
(
Matrix
<
float
>
(
*
input
),
&
session_state_
);
}
// Returns the layer named |layer_name| in the current component.
Matrix
<
float
>
GetActivations
(
const
string
&
layer_name
)
const
{
return
Matrix
<
float
>
(
GetLayer
(
kTestComponentName
,
layer_name
));
}
std
::
unique_ptr
<
BulkNetworkUnit
>
bulk_network_unit_
;
};
// Tests that BulkFeedForwardNetwork fails when a weight matrix does not match
// the dimension of its output activations.
TEST_F
(
BulkFeedForwardNetworkTest
,
BadWeightRows
)
{
const
string
kBadSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
num_actions: 3)"
;
AddWeights
(
"softmax"
,
kInputDim
,
kLogitsDim
-
1
/* bad */
,
1.0
);
AddBiases
(
"softmax"
,
kLogitsDim
,
1.0
);
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Weight matrix shape should be output dimension plus padding"
));
}
// Tests that BulkFeedForwardNetwork fails when a weight matrix does not match
// the dimension of its input activations.
TEST_F
(
BulkFeedForwardNetworkTest
,
BadWeightColumns
)
{
const
string
kBadSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
num_actions: 3)"
;
AddWeights
(
"softmax"
,
kInputDim
+
1
/* bad */
,
kLogitsDim
,
1.0
);
AddBiases
(
"softmax"
,
kLogitsDim
,
1.0
);
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Weight matrix shape does not match input dimension"
));
}
// Tests that BulkFeedForwardNetwork fails when a bias vector does not match the
// dimension of its output activations.
TEST_F
(
BulkFeedForwardNetworkTest
,
BadBiasDimension
)
{
const
string
kBadSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
num_actions: 3)"
;
AddWeights
(
"softmax"
,
kInputDim
,
kLogitsDim
,
1.0
);
AddBiases
(
"softmax"
,
kLogitsDim
+
1
/* bad */
,
1.0
);
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Bias vector shape does not match output dimension"
));
}
// Tests that BulkFeedForwardNetwork fails when the value of the
// "layer_norm_input" option is not false.
TEST_F
(
BulkFeedForwardNetworkTest
,
UnsupportedLayerNormInputOption
)
{
const
string
kBadSpec
=
R"(network_unit {
parameters {
key: 'layer_norm_input'
value: 'true'
}
})"
;
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Layer norm is not supported"
));
}
// Tests that BulkFeedForwardNetwork fails when the value of the
// "layer_norm_hidden" option is not false.
TEST_F
(
BulkFeedForwardNetworkTest
,
UnsupportedLayerNormHiddenOption
)
{
const
string
kBadSpec
=
R"(network_unit {
parameters {
key: 'layer_norm_hidden'
value: 'true'
}
})"
;
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Layer norm is not supported"
));
}
// Tests that BulkFeedForwardNetwork fails when the value of the "nonlinearity"
// option is not "relu".
TEST_F
(
BulkFeedForwardNetworkTest
,
UnsupportedNonlinearityOption
)
{
const
string
kBadSpec
=
R"(network_unit {
parameters {
key: 'nonlinearity'
value: 'elu'
}
})"
;
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Non-linearity is not supported"
));
}
// Tests that BulkFeedForwardNetwork fails if there is a recurrent link.
TEST_F
(
BulkFeedForwardNetworkTest
,
UnsupportedRecurrentLink
)
{
const
string
kBadSpec
=
R"(linked_feature {
source_component: 'test_component'
})"
;
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"BulkFeedForwardNetwork forbids recurrent links"
));
}
// Tests that the BulkFeedForwardNetwork works when there are no hidden layers,
// just a softmax that computes logits.
TEST_F
(
BulkFeedForwardNetworkTest
,
JustLogits
)
{
const
string
kSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
num_actions: 3)"
;
const
float
kWeight
=
1.5
;
const
float
kBias
=
0.75
;
AddWeights
(
"softmax"
,
kInputDim
,
kLogitsDim
,
kWeight
);
AddBiases
(
"softmax"
,
kLogitsDim
,
kBias
);
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
"logits"
,
bulk_network_unit_
->
GetLogitsName
());
ExpectMatrix
(
GetActivations
(
"logits"
),
kNumSteps
,
kLogitsDim
,
kInputDim
*
kEmbedding
*
kWeight
+
kBias
);
}
// Tests that the BulkFeedForwardNetwork works with multiple hidden layers as
// well as a softmax that computes logits.
TEST_F
(
BulkFeedForwardNetworkTest
,
MultiLayer
)
{
const
size_t
kDims
[]
=
{
kInputDim
,
4
,
3
,
2
};
const
string
kSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4,3'
}
}
num_actions: 2)"
;
const
float
kWeights
[]
=
{
-
1.5
,
1.0
,
0.5
};
const
float
kBiases
[]
=
{
0.75
,
-
0.5
,
-
1.0
};
AddWeights
(
"0"
,
kDims
[
0
],
kDims
[
1
],
kWeights
[
0
]);
AddBiases
(
"0"
,
kDims
[
1
],
kBiases
[
0
]);
AddWeights
(
"1"
,
kDims
[
1
],
kDims
[
2
],
kWeights
[
1
]);
AddBiases
(
"1"
,
kDims
[
2
],
kBiases
[
1
]);
AddWeights
(
"softmax"
,
kDims
[
2
],
kDims
[
3
],
kWeights
[
2
]);
AddBiases
(
"softmax"
,
kDims
[
3
],
kBiases
[
2
]);
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
"logits"
,
bulk_network_unit_
->
GetLogitsName
());
float
expected
=
Relu
(
kDims
[
0
]
*
kWeights
[
0
]
+
kBiases
[
0
]);
ExpectMatrix
(
GetActivations
(
"layer_0"
),
kNumSteps
,
kDims
[
1
],
expected
);
expected
=
Relu
(
kDims
[
1
]
*
expected
*
kWeights
[
1
]
+
kBiases
[
1
]);
ExpectMatrix
(
GetActivations
(
"layer_1"
),
kNumSteps
,
kDims
[
2
],
expected
);
ExpectMatrix
(
GetActivations
(
"last_layer"
),
kNumSteps
,
kDims
[
2
],
expected
);
expected
=
kDims
[
2
]
*
expected
*
kWeights
[
2
]
+
kBiases
[
2
];
ExpectMatrix
(
GetActivations
(
"logits"
),
kNumSteps
,
kDims
[
3
],
expected
);
}
// Tests that the BulkFeedForwardNetwork does not produce logits and does not
// use the softmax variables when the component is deterministic.
TEST_F
(
BulkFeedForwardNetworkTest
,
NoLogitsOrSoftmaxWhenDeterministic
)
{
const
size_t
kDims
[]
=
{
kInputDim
,
4
};
const
string
kSpec
=
R"(num_actions: 1
fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4'
}
})"
;
const
float
kWeight
=
-
1.5
;
const
float
kBias
=
0.75
;
// No "softmax" weights or biases.
AddWeights
(
"0"
,
kDims
[
0
],
kDims
[
1
],
kWeight
);
AddBiases
(
"0"
,
kDims
[
1
],
kBias
);
TF_ASSERT_OK
(
Run
(
kSpec
));
// No specified logits layer.
EXPECT_TRUE
(
bulk_network_unit_
->
GetLogitsName
().
empty
());
// No "logits" layer.
size_t
unused_dimension
=
0
;
LayerHandle
<
float
>
unused_handle
;
EXPECT_THAT
(
network_state_manager_
.
LookupLayer
(
kTestComponentName
,
"logits"
,
&
unused_dimension
,
&
unused_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'logits' in component 'test_component'"
));
// Hidden layer is still produced.
const
float
kExpected
=
Relu
(
kDims
[
0
]
*
kEmbedding
*
kWeight
+
kBias
);
ExpectMatrix
(
GetActivations
(
"layer_0"
),
kNumSteps
,
kDims
[
1
],
kExpected
);
ExpectMatrix
(
GetActivations
(
"last_layer"
),
kNumSteps
,
kDims
[
1
],
kExpected
);
}
// Tests that the BulkFeedForwardNetwork does not produce logits when
// omit_logits is true, even if there are actions.
TEST_F
(
BulkFeedForwardNetworkTest
,
NoLogitsOrSoftmaxWhenOmitLogitsTrue
)
{
const
size_t
kDims
[]
=
{
kInputDim
,
4
};
const
string
kSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4'
}
parameters {
key: 'omit_logits'
value: 'true'
}
}
num_actions: 10)"
;
const
float
kWeight
=
1.5
;
const
float
kBias
=
0.75
;
// No "softmax" weights or biases.
AddWeights
(
"0"
,
kDims
[
0
],
kDims
[
1
],
kWeight
);
AddBiases
(
"0"
,
kDims
[
1
],
kBias
);
TF_ASSERT_OK
(
Run
(
kSpec
));
// No specified logits layer.
EXPECT_TRUE
(
bulk_network_unit_
->
GetLogitsName
().
empty
());
// No "logits" layer.
size_t
unused_dimension
=
0
;
LayerHandle
<
float
>
unused_handle
;
EXPECT_THAT
(
network_state_manager_
.
LookupLayer
(
kTestComponentName
,
"logits"
,
&
unused_dimension
,
&
unused_handle
),
test
::
IsErrorWithSubstr
(
"Unknown layer 'logits' in component 'test_component'"
));
// Hidden layer is still produced.
const
float
kExpected
=
kDims
[
0
]
*
kEmbedding
*
kWeight
+
kBias
;
ExpectMatrix
(
GetActivations
(
"layer_0"
),
kNumSteps
,
kDims
[
1
],
kExpected
);
ExpectMatrix
(
GetActivations
(
"last_layer"
),
kNumSteps
,
kDims
[
1
],
kExpected
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
2
3
4
5
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment