# Commit edea2b67, authored by Terry Koo.
#
# Remove runtime because reasons.
#
# Parent commit: a4bb31d0
# Buildifier convention (load-on-top): load() statements come first,
# then the package() declaration.
load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "if_linux_x86_64",
)
load(
    "//dragnn/runtime:multiarch.bzl",
    "dragnn_cc_multiarch_binary",
    "dragnn_cc_multiarch_library",
    "dragnn_cc_multiarch_test",
)

# All targets in this package are publicly visible.
package(
    default_visibility = ["//visibility:public"],
)
# Aggressive math/vectorization copts, applied only on linux-x86_64 builds
# (empty elsewhere via if_linux_x86_64).
FAST_MATH_COPTS = if_linux_x86_64([
    "-O3",
    "-msse4.2",
    "-ffast-math",
    "-ftree-vectorize",
])

# Saved-model test fixture shared by several tests below.
filegroup(
    name = "test_rnn_tagger",
    srcs = glob(["testdata/rnn_tagger/**"]),
)
# Header-only alignment utilities.
cc_library(
    name = "alignment",
    hdrs = ["alignment.h"],
    deps = [
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "alignment_test",
    size = "small",
    srcs = ["alignment_test.cc"],
    deps = [
        ":alignment",
        "//dragnn/core/test:generic",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Memory-mapped file support (mmap.h/mmap.cc).
cc_library(
    name = "mmap",
    srcs = ["mmap.cc"],
    hdrs = ["mmap.h"],
    deps = [
        ":alignment",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "mmap_test",
    size = "small",
    srcs = ["mmap_test.cc"],
    # File fixtures exercised by the test.
    data = [
        "testdata/empty_file",
        "testdata/ten_bytes",
    ],
    deps = [
        ":mmap",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "operands",
    srcs = ["operands.cc"],
    hdrs = ["operands.h"],
    deps = [
        ":alignment",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "operands_test",
    size = "small",
    srcs = ["operands_test.cc"],
    deps = [
        ":alignment",
        ":operands",
        "//dragnn/runtime/math:types",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Header-only variable-store interface; concrete implementations below.
cc_library(
    name = "variable_store",
    hdrs = ["variable_store.h"],
    deps = [
        ":alignment",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "variable_store_test",
    size = "small",
    srcs = ["variable_store_test.cc"],
    deps = [
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/test:fake_variable_store",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Wrapper around a TF saved model; pulls in the DRAGNN/SyntaxNet op libraries.
cc_library(
    name = "trained_model",
    srcs = ["trained_model.cc"],
    hdrs = ["trained_model.h"],
    deps = [
        "//dragnn/core:dragnn_bulk_ops_cc",
        "//dragnn/core:dragnn_ops_cc",
        "//syntaxnet:base",
        "//syntaxnet:parser_ops_cc",
        "@org_tensorflow//tensorflow/cc/saved_model:loader",
        "@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
        "@org_tensorflow//tensorflow/core:core_cpu",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "trained_model_test",
    size = "small",
    timeout = "moderate",
    srcs = ["trained_model_test.cc"],
    data = [":test_rnn_tagger"],
    deps = [
        ":trained_model",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# VariableStore implementation backed by a trained model.
cc_library(
    name = "trained_model_variable_store",
    srcs = ["trained_model_variable_store.cc"],
    hdrs = ["trained_model_variable_store.h"],
    deps = [
        ":alignment",
        ":trained_model",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:core_cpu",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:tensorflow",
    ],
)

cc_test(
    name = "trained_model_variable_store_test",
    size = "small",
    timeout = "moderate",
    srcs = ["trained_model_variable_store_test.cc"],
    data = [":test_rnn_tagger"],
    shard_count = 13,
    deps = [
        ":trained_model_variable_store",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:float16_types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Decorators around VariableStore (see variable_store_wrappers.h).
cc_library(
    name = "variable_store_wrappers",
    srcs = ["variable_store_wrappers.cc"],
    hdrs = ["variable_store_wrappers.h"],
    deps = [
        ":alignment",
        ":flexible_matrix_kernel",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "variable_store_wrappers_test",
    size = "small",
    srcs = ["variable_store_wrappers_test.cc"],
    deps = [
        ":flexible_matrix_kernel",
        ":variable_store_wrappers",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:fake_variable_store",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# VariableStore backed by a pre-built byte array.
cc_library(
    name = "array_variable_store",
    srcs = ["array_variable_store.cc"],
    hdrs = ["array_variable_store.h"],
    deps = [
        ":alignment",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Also covers :file_array_variable_store and :mmap_array_variable_store.
cc_test(
    name = "array_variable_store_test",
    size = "small",
    srcs = ["array_variable_store_test.cc"],
    data = [
        "testdata/array_variable_store_data",
        "testdata/array_variable_store_spec",
        "testdata/empty_file",
    ],
    deps = [
        ":alignment",
        ":array_variable_store",
        ":file_array_variable_store",
        ":mmap_array_variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Builds the array/spec pair consumed by :array_variable_store.
cc_library(
    name = "array_variable_store_builder",
    srcs = ["array_variable_store_builder.cc"],
    hdrs = ["array_variable_store_builder.h"],
    deps = [
        ":alignment",
        ":array_variable_store",
        ":variable_store_wrappers",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "array_variable_store_builder_test",
    size = "small",
    srcs = ["array_variable_store_builder_test.cc"],
    data = [
        "testdata/array_variable_store_data",
        "testdata/array_variable_store_spec",
    ],
    deps = [
        ":alignment",
        ":array_variable_store_builder",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Tested in array_variable_store_test.
cc_library(
    name = "file_array_variable_store",
    srcs = ["file_array_variable_store.cc"],
    hdrs = ["file_array_variable_store.h"],
    deps = [
        ":alignment",
        ":array_variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Tested in array_variable_store_test.
cc_library(
    name = "mmap_array_variable_store",
    srcs = ["mmap_array_variable_store.cc"],
    hdrs = ["mmap_array_variable_store.h"],
    deps = [
        ":array_variable_store",
        ":mmap",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_library(
    name = "network_states",
    srcs = ["network_states.cc"],
    hdrs = ["network_states.h"],
    deps = [
        ":alignment",
        ":operands",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "network_states_test",
    size = "small",
    srcs = ["network_states_test.cc"],
    deps = [
        ":alignment",
        ":network_states",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "extensions",
    srcs = ["extensions.cc"],
    hdrs = ["extensions.h"],
    deps = [
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "extensions_test",
    size = "small",
    srcs = ["extensions_test.cc"],
    deps = [
        ":extensions",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Multiarch (per-architecture) library; built with fast-math copts.
dragnn_cc_multiarch_library(
    name = "linked_embeddings",
    srcs = ["linked_embeddings.cc"],
    hdrs = ["linked_embeddings.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":alignment",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "linked_embeddings_test",
    size = "small",
    srcs = ["linked_embeddings_test.cc"],
    deps = [
        ":linked_embeddings",
        ":network_states",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

dragnn_cc_multiarch_library(
    name = "fixed_embeddings",
    srcs = ["fixed_embeddings.cc"],
    hdrs = ["fixed_embeddings.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":alignment",
        ":network_states",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "fixed_embeddings_test",
    size = "small",
    srcs = ["fixed_embeddings_test.cc"],
    deps = [
        ":fixed_embeddings",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Header-only container keyed by type (no deps).
cc_library(
    name = "type_keyed_set",
    hdrs = ["type_keyed_set.h"],
)

cc_test(
    name = "type_keyed_set_test",
    size = "small",
    srcs = ["type_keyed_set_test.cc"],
    deps = [
        ":type_keyed_set",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Header-only aggregation of per-session state.
cc_library(
    name = "session_state",
    hdrs = ["session_state.h"],
    deps = [
        ":extensions",
        ":network_states",
    ],
)

cc_library(
    name = "session_state_pool",
    srcs = ["session_state_pool.cc"],
    hdrs = ["session_state_pool.h"],
    deps = [
        ":session_state",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "session_state_pool_test",
    size = "small",
    srcs = ["session_state_pool_test.cc"],
    deps = [
        ":session_state",
        ":session_state_pool",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# alwayslink = 1: the component registers itself via static initializers,
# so the linker must not drop it.
dragnn_cc_multiarch_library(
    name = "bulk_dynamic_component",
    srcs = ["bulk_dynamic_component.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit_base",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "bulk_dynamic_component_test",
    srcs = ["bulk_dynamic_component_test.cc"],
    deps = [
        ":bulk_dynamic_component",
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

dragnn_cc_multiarch_library(
    name = "sequence_bulk_dynamic_component",
    srcs = ["sequence_bulk_dynamic_component.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_model",
        ":session_state",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "sequence_bulk_dynamic_component_test",
    srcs = ["sequence_bulk_dynamic_component_test.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":sequence_backend",
        ":sequence_bulk_dynamic_component",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        ":variable_store",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Base Component interface; implementations register via //syntaxnet:registry.
cc_library(
    name = "component",
    srcs = ["component.cc"],
    hdrs = ["component.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "component_test",
    size = "small",
    srcs = ["component_test.cc"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# LSTM cell kernel shared by the LSTM network units below.
dragnn_cc_multiarch_library(
    name = "lstm_network_kernel",
    srcs = ["lstm_network_kernel.cc"],
    hdrs = ["lstm_network_kernel.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":attributes",
        ":extensions",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":session_state",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/math:avx_activation_functions",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "lstm_network_kernel_test",
    srcs = ["lstm_network_kernel_test.cc"],
    deps = [
        ":lstm_network_kernel",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Self-registering network unit (alwayslink = 1).
dragnn_cc_multiarch_library(
    name = "lstm_network",
    srcs = ["lstm_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":extensions",
        ":lstm_network_kernel",
        ":network_unit",
        ":network_unit_base",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/math:avx_activation_functions",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "lstm_network_test",
    srcs = ["lstm_network_test.cc"],
    deps = [
        ":flexible_matrix_kernel",
        ":lstm_network",
        ":network_unit",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Bulk (whole-sequence) counterpart of :lstm_network; self-registering.
dragnn_cc_multiarch_library(
    name = "bulk_lstm_network",
    srcs = ["bulk_lstm_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":lstm_network_kernel",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "bulk_lstm_network_test",
    srcs = ["bulk_lstm_network_test.cc"],
    deps = [
        ":bulk_lstm_network",
        ":bulk_network_unit",
        ":flexible_matrix_kernel",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Top-level runtime driver that owns components and session state.
cc_library(
    name = "master",
    srcs = ["master.cc"],
    hdrs = ["master.h"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":session_state_pool",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "master_test",
    size = "small",
    srcs = ["master_test.cc"],
    deps = [
        ":alignment",
        ":component",
        ":extensions",
        ":master",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/core/test:mock_compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:fake_variable_store",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Base NetworkUnit interface; implementations register via //syntaxnet:registry.
cc_library(
    name = "network_unit",
    srcs = ["network_unit.cc"],
    hdrs = ["network_unit.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "network_unit_test",
    size = "small",
    srcs = ["network_unit_test.cc"],
    deps = [
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Bulk variant of the NetworkUnit interface.
cc_library(
    name = "bulk_network_unit",
    srcs = ["bulk_network_unit.cc"],
    hdrs = ["bulk_network_unit.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "bulk_network_unit_test",
    size = "small",
    srcs = ["bulk_network_unit_test.cc"],
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Self-registering component (alwayslink = 1).
cc_library(
    name = "dynamic_component",
    srcs = ["dynamic_component.cc"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

cc_test(
    name = "dynamic_component_test",
    size = "small",
    srcs = ["dynamic_component_test.cc"],
    deps = [
        ":component",
        ":dynamic_component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Shared base implementation for network units (fixed + linked embeddings).
dragnn_cc_multiarch_library(
    name = "network_unit_base",
    srcs = ["network_unit_base.cc"],
    hdrs = ["network_unit_base.h"],
    deps = [
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "network_unit_base_test",
    size = "small",
    srcs = ["network_unit_base_test.cc"],
    deps = [
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":network_unit_base",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "attributes",
    srcs = ["attributes.cc"],
    hdrs = ["attributes.h"],
    deps = [
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "attributes_test",
    size = "small",
    srcs = ["attributes_test.cc"],
    deps = [
        ":attributes",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Header-only activation functions.
cc_library(
    name = "activation_functions",
    hdrs = ["activation_functions.h"],
    deps = [
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
    ],
)

cc_test(
    name = "activation_functions_test",
    size = "small",
    srcs = ["activation_functions_test.cc"],
    deps = [
        ":activation_functions",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "flexible_matrix_kernel",
    srcs = ["flexible_matrix_kernel.cc"],
    hdrs = ["flexible_matrix_kernel.h"],
    deps = [
        ":alignment",
        ":variable_store",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:sgemvv",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "flexible_matrix_kernel_test",
    srcs = ["flexible_matrix_kernel_test.cc"],
    copts = FAST_MATH_COPTS,
    deps = [
        ":flexible_matrix_kernel",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/test:fake_variable_store",
        "//dragnn/runtime/test:helpers",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Single feed-forward layer; self-registering (alwayslink = 1).
dragnn_cc_multiarch_library(
    name = "feed_forward_network_layer",
    srcs = ["feed_forward_network_layer.cc"],
    hdrs = ["feed_forward_network_layer.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":activation_functions",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "feed_forward_network_layer_test",
    size = "small",
    srcs = ["feed_forward_network_layer_test.cc"],
    deps = [
        ":activation_functions",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Feed-forward kernel shared by :feed_forward_network and
# :bulk_feed_forward_network; self-registering (alwayslink = 1).
dragnn_cc_multiarch_library(
    name = "feed_forward_network_kernel",
    srcs = ["feed_forward_network_kernel.cc"],
    hdrs = ["feed_forward_network_kernel.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":activation_functions",
        ":attributes",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "feed_forward_network_kernel_test",
    size = "small",
    srcs = ["feed_forward_network_kernel_test.cc"],
    deps = [
        ":activation_functions",
        ":feed_forward_network_kernel",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

dragnn_cc_multiarch_library(
    name = "feed_forward_network",
    srcs = ["feed_forward_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":extensions",
        ":feed_forward_network_kernel",
        ":feed_forward_network_layer",
        ":network_states",
        ":network_unit",
        ":network_unit_base",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "feed_forward_network_test",
    size = "small",
    srcs = ["feed_forward_network_test.cc"],
    deps = [
        ":dynamic_component",
        ":feed_forward_network",
        ":flexible_matrix_kernel",
        ":network_states",
        ":network_unit",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Bulk counterpart of :feed_forward_network; self-registering (alwayslink = 1).
dragnn_cc_multiarch_library(
    name = "bulk_feed_forward_network",
    srcs = ["bulk_feed_forward_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":feed_forward_network_kernel",
        ":feed_forward_network_layer",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "bulk_feed_forward_network_test",
    size = "small",
    srcs = ["bulk_feed_forward_network_test.cc"],
    deps = [
        ":bulk_feed_forward_network",
        ":bulk_network_unit",
        ":dynamic_component",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Converts trained models into the array-variable-store format.
cc_library(
    name = "conversion",
    srcs = ["conversion.cc"],
    hdrs = ["conversion.h"],
    deps = [
        ":array_variable_store_builder",
        ":master",
        ":trained_model_variable_store",
        ":variable_store",
        ":variable_store_wrappers",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "conversion_test",
    size = "small",
    timeout = "moderate",
    srcs = ["conversion_test.cc"],
    data = [
        "testdata/conversion_output_variables_data",
        "testdata/conversion_output_variables_spec",
        ":test_rnn_tagger",
    ],
    shard_count = 6,
    deps = [
        ":conversion",
        ":dynamic_component",
        ":feed_forward_network",
        ":lstm_network",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Registry-based spec transformations for components.
cc_library(
    name = "component_transformation",
    srcs = ["component_transformation.cc"],
    hdrs = ["component_transformation.h"],
    deps = [
        ":component",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "component_transformation_test",
    size = "small",
    srcs = ["component_transformation_test.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# FML (feature markup language) parsing helpers.
cc_library(
    name = "fml_parsing",
    srcs = ["fml_parsing.cc"],
    hdrs = ["fml_parsing.h"],
    deps = [
        ":attributes",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "//syntaxnet:fml_parser",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "fml_parsing_test",
    size = "small",
    srcs = ["fml_parsing_test.cc"],
    deps = [
        ":fml_parsing",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "term_map_utils",
    srcs = ["term_map_utils.cc"],
    hdrs = ["term_map_utils.h"],
    deps = [
        ":fml_parsing",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "term_map_utils_test",
    size = "small",
    srcs = ["term_map_utils_test.cc"],
    deps = [
        ":term_map_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "transition_system_traits",
    srcs = ["transition_system_traits.cc"],
    hdrs = ["transition_system_traits.h"],
    deps = [
        ":attributes",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "transition_system_traits_test",
    size = "small",
    srcs = ["transition_system_traits_test.cc"],
    deps = [
        ":transition_system_traits",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Unicode term dictionary built from a term frequency map.
cc_library(
    name = "unicode_dictionary",
    srcs = ["unicode_dictionary.cc"],
    hdrs = ["unicode_dictionary.h"],
    deps = [
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "unicode_dictionary_test",
    size = "small",
    timeout = "moderate",
    srcs = ["unicode_dictionary_test.cc"],
    deps = [
        ":unicode_dictionary",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//third_party/utf",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Registry-based interface for extracting feature sequences from input batches.
cc_library(
    name = "sequence_extractor",
    srcs = ["sequence_extractor.cc"],
    hdrs = ["sequence_extractor.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "sequence_extractor_test",
    size = "small",
    srcs = ["sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Header-only base class for term-map-backed sequence extractors.
cc_library(
    name = "term_map_sequence_extractor",
    hdrs = ["term_map_sequence_extractor.h"],
    deps = [
        ":sequence_extractor",
        ":term_map_utils",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:shared_store",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "term_map_sequence_extractor_test",
    size = "small",
    srcs = ["term_map_sequence_extractor_test.cc"],
    deps = [
        ":term_map_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Character-based sequence extractor for SyntaxNet sentences.  alwayslink so
# its static registration is retained even when nothing references it directly.
cc_library(
    name = "syntaxnet_character_sequence_extractor",
    srcs = ["syntaxnet_character_sequence_extractor.cc"],
    deps = [
        ":sequence_extractor",
        ":term_map_sequence_extractor",
        ":term_map_utils",
        ":transition_system_traits",
        ":unicode_dictionary",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:segmenter_utils",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_character_sequence_extractor.
cc_test(
    name = "syntaxnet_character_sequence_extractor_test",
    size = "small",
    srcs = ["syntaxnet_character_sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        ":syntaxnet_character_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Word-based sequence extractor for SyntaxNet sentences; alwayslink keeps its
# static registration alive.
cc_library(
    name = "syntaxnet_word_sequence_extractor",
    srcs = ["syntaxnet_word_sequence_extractor.cc"],
    deps = [
        ":sequence_extractor",
        ":term_map_sequence_extractor",
        ":term_map_utils",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_word_sequence_extractor.
cc_test(
    name = "syntaxnet_word_sequence_extractor_test",
    size = "small",
    srcs = ["syntaxnet_word_sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        ":syntaxnet_word_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Fixed-feature handling for sequence-based components.  Built as a multiarch
# library so arch-specific (vectorized) variants can be produced.
dragnn_cc_multiarch_library(
    name = "sequence_features",
    srcs = ["sequence_features.cc"],
    hdrs = ["sequence_features.h"],
    deps = [
        ":alignment",
        ":fixed_embeddings",
        ":sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multiarch unit test for :sequence_features.
dragnn_cc_multiarch_test(
    name = "sequence_features_test",
    size = "small",
    srcs = ["sequence_features_test.cc"],
    deps = [
        ":fixed_embeddings",
        ":sequence_extractor",
        ":sequence_features",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# SequenceLinker interface; registered via //syntaxnet:registry so concrete
# linkers below can be looked up by name.
cc_library(
    name = "sequence_linker",
    srcs = ["sequence_linker.cc"],
    hdrs = ["sequence_linker.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for the SequenceLinker interface.
cc_test(
    name = "sequence_linker_test",
    size = "small",
    srcs = ["sequence_linker_test.cc"],
    deps = [
        ":sequence_linker",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Identity link implementation; alwayslink keeps its registration alive.
cc_library(
    name = "identity_sequence_linker",
    srcs = ["identity_sequence_linker.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :identity_sequence_linker.
cc_test(
    name = "identity_sequence_linker_test",
    size = "small",
    srcs = ["identity_sequence_linker_test.cc"],
    deps = [
        ":identity_sequence_linker",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Reversed-sequence link implementation; alwayslink for registration.
cc_library(
    name = "reversed_sequence_linker",
    srcs = ["reversed_sequence_linker.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :reversed_sequence_linker.
cc_test(
    name = "reversed_sequence_linker_test",
    size = "small",
    srcs = ["reversed_sequence_linker_test.cc"],
    deps = [
        ":reversed_sequence_linker",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Recurrent (step-to-previous-step) link implementations; alwayslink keeps
# their registrations alive.
cc_library(
    name = "recurrent_sequence_linkers",
    srcs = ["recurrent_sequence_linkers.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :recurrent_sequence_linkers.
cc_test(
    name = "recurrent_sequence_linkers_test",
    size = "small",
    srcs = ["recurrent_sequence_linkers_test.cc"],
    deps = [
        ":recurrent_sequence_linkers",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Character-level linkers over SyntaxNet sentences (Unicode-aware via
# //util/utf8:unicodetext); alwayslink for registration.
cc_library(
    name = "syntaxnet_character_sequence_linkers",
    srcs = ["syntaxnet_character_sequence_linkers.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_character_sequence_linkers.
cc_test(
    name = "syntaxnet_character_sequence_linkers_test",
    size = "small",
    srcs = ["syntaxnet_character_sequence_linkers_test.cc"],
    deps = [
        ":sequence_linker",
        ":syntaxnet_character_sequence_linkers",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Linked-feature handling for sequence-based components (multiarch so
# arch-specific variants can be built).
dragnn_cc_multiarch_library(
    name = "sequence_links",
    srcs = ["sequence_links.cc"],
    hdrs = ["sequence_links.h"],
    deps = [
        ":alignment",
        ":linked_embeddings",
        ":network_states",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multiarch unit test for :sequence_links.
dragnn_cc_multiarch_test(
    name = "sequence_links_test",
    size = "small",
    srcs = ["sequence_links_test.cc"],
    deps = [
        ":linked_embeddings",
        ":network_states",
        ":sequence_linker",
        ":sequence_links",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# SequencePredictor interface; registered via //syntaxnet:registry.
cc_library(
    name = "sequence_predictor",
    srcs = ["sequence_predictor.cc"],
    hdrs = ["sequence_predictor.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for the SequencePredictor interface.
cc_test(
    name = "sequence_predictor_test",
    size = "small",
    srcs = ["sequence_predictor_test.cc"],
    deps = [
        ":sequence_predictor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Base for predictors that map indices back to strings via a shared term map.
cc_library(
    name = "term_map_sequence_predictor",
    srcs = ["term_map_sequence_predictor.cc"],
    hdrs = ["term_map_sequence_predictor.h"],
    deps = [
        ":sequence_predictor",
        ":term_map_utils",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:shared_store",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :term_map_sequence_predictor.
cc_test(
    name = "term_map_sequence_predictor_test",
    size = "small",
    srcs = ["term_map_sequence_predictor_test.cc"],
    deps = [
        ":term_map_sequence_predictor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# POS-tag predictor for SyntaxNet sentences; alwayslink for registration.
cc_library(
    name = "syntaxnet_tag_sequence_predictor",
    srcs = ["syntaxnet_tag_sequence_predictor.cc"],
    deps = [
        ":sequence_predictor",
        ":term_map_sequence_predictor",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_tag_sequence_predictor.
cc_test(
    name = "syntaxnet_tag_sequence_predictor_test",
    size = "small",
    srcs = ["syntaxnet_tag_sequence_predictor_test.cc"],
    deps = [
        ":alignment",
        ":sequence_predictor",
        ":syntaxnet_tag_sequence_predictor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Backend Component implementation for sequence-based models; registered in
# //dragnn/core:component_registry, hence alwayslink.
cc_library(
    name = "sequence_backend",
    srcs = ["sequence_backend.cc"],
    hdrs = ["sequence_backend.h"],
    deps = [
        "//dragnn/core:component_registry",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/interfaces:component",
        "//dragnn/core/interfaces:transition_state",
        "//dragnn/core/util:label",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :sequence_backend.
cc_test(
    name = "sequence_backend_test",
    size = "small",
    srcs = ["sequence_backend_test.cc"],
    deps = [
        ":sequence_backend",
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/interfaces:transition_state",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Spec transformers applied during model conversion.  Each transformer library
# is alwayslink so its static registration survives linking.

# Transformer that selects the best component implementation.
cc_library(
    name = "select_best_component_transformer",
    srcs = ["select_best_component_transformer.cc"],
    deps = [
        ":component",
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :select_best_component_transformer.
cc_test(
    name = "select_best_component_transformer_test",
    size = "small",
    srcs = ["select_best_component_transformer_test.cc"],
    deps = [
        ":component",
        ":component_transformation",
        ":extensions",
        ":select_best_component_transformer",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Transformer that rewrites components into their sequence-based equivalents.
cc_library(
    name = "sequence_component_transformer",
    srcs = ["sequence_component_transformer.cc"],
    deps = [
        ":component_transformation",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        ":transition_system_traits",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :sequence_component_transformer.
cc_test(
    name = "sequence_component_transformer_test",
    size = "small",
    srcs = ["sequence_component_transformer_test.cc"],
    deps = [
        ":component_transformation",
        ":sequence_component_transformer",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Transformer targeting stateless components.
cc_library(
    name = "stateless_component_transformer",
    srcs = ["stateless_component_transformer.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :stateless_component_transformer.
cc_test(
    name = "stateless_component_transformer_test",
    size = "small",
    srcs = ["stateless_component_transformer_test.cc"],
    deps = [
        ":component_transformation",
        ":stateless_component_transformer",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Transformer that strips dropout from feature specs (uses the FML parser).
cc_library(
    name = "clear_dropout_component_transformer",
    srcs = ["clear_dropout_component_transformer.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "//syntaxnet:fml_parser",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :clear_dropout_component_transformer.
cc_test(
    name = "clear_dropout_component_transformer_test",
    size = "small",
    srcs = ["clear_dropout_component_transformer_test.cc"],
    deps = [
        ":clear_dropout_component_transformer",
        ":component_transformation",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Shared model logic for sequence-based components, tying together features,
# links, predictors, and the sequence backend (multiarch).
dragnn_cc_multiarch_library(
    name = "sequence_model",
    srcs = ["sequence_model.cc"],
    hdrs = ["sequence_model.h"],
    deps = [
        ":attributes",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_backend",
        ":sequence_features",
        ":sequence_links",
        ":sequence_predictor",
        ":session_state",
        ":transition_system_traits",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multiarch unit test for :sequence_model.
dragnn_cc_multiarch_test(
    name = "sequence_model_test",
    size = "small",
    srcs = ["sequence_model_test.cc"],
    deps = [
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_backend",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_model",
        ":sequence_predictor",
        ":session_state",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Biaffine digraph scoring component.  Compiled with FAST_MATH_COPTS
# (-O3/-msse4.2/-ffast-math on linux-x86_64) for vectorized Eigen math;
# alwayslink keeps its component registration alive.
dragnn_cc_multiarch_library(
    name = "biaffine_digraph_component",
    srcs = ["biaffine_digraph_component.cc"],
    copts = FAST_MATH_COPTS,
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:eigen",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multiarch unit test for :biaffine_digraph_component.
dragnn_cc_multiarch_test(
    name = "biaffine_digraph_component_test",
    size = "small",
    srcs = ["biaffine_digraph_component_test.cc"],
    deps = [
        ":biaffine_digraph_component",
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Shared base for head-selection components (dependency-parse head prediction).
cc_library(
    name = "head_selection_component_base",
    srcs = ["head_selection_component_base.cc"],
    hdrs = ["head_selection_component_base.h"],
    deps = [
        ":alignment",
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :head_selection_component_base.
cc_test(
    name = "head_selection_component_base_test",
    size = "small",
    srcs = ["head_selection_component_base_test.cc"],
    deps = [
        ":head_selection_component_base",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Head-selection component specialized to SyntaxNet sentences; alwayslink
# keeps its component registration alive.
cc_library(
    name = "syntaxnet_head_selection_component",
    srcs = ["syntaxnet_head_selection_component.cc"],
    deps = [
        ":head_selection_component_base",
        ":session_state",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_head_selection_component.
cc_test(
    name = "syntaxnet_head_selection_component_test",
    size = "small",
    srcs = ["syntaxnet_head_selection_component_test.cc"],
    deps = [
        ":component",
        ":syntaxnet_head_selection_component",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Shared base for components that decode parse trees with the MST solver
# (//dragnn/mst:mst_solver).
cc_library(
    name = "mst_solver_component_base",
    srcs = ["mst_solver_component_base.cc"],
    hdrs = ["mst_solver_component_base.h"],
    deps = [
        ":attributes",
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/mst:mst_solver",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :mst_solver_component_base.
cc_test(
    name = "mst_solver_component_base_test",
    size = "small",
    srcs = ["mst_solver_component_base_test.cc"],
    deps = [
        ":mst_solver_component_base",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# MST solver component specialized to SyntaxNet sentences; alwayslink keeps
# its component registration alive.
cc_library(
    name = "syntaxnet_mst_solver_component",
    srcs = ["syntaxnet_mst_solver_component.cc"],
    deps = [
        ":mst_solver_component_base",
        ":session_state",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_mst_solver_component.
cc_test(
    name = "syntaxnet_mst_solver_component_test",
    size = "small",
    srcs = ["syntaxnet_mst_solver_component_test.cc"],
    deps = [
        ":component",
        ":syntaxnet_mst_solver_component",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Entry-point library for the model converter binary (main lives in
# converter.cc; note the sling and myelin/xla compilation deps).
cc_library(
    name = "converter_main",
    srcs = ["converter.cc"],
    deps = [
        ":component_transformation",
        ":conversion",
        "//dragnn/runtime/myelin:myelination",
        "//dragnn/runtime/xla:xla_compilation",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@sling//sling/base",
    ],
)

# Converter binary.  Links in every component, network, linker, extractor,
# predictor, and transformer implementation (all alwayslink) so their
# registrations are available at conversion time.
dragnn_cc_multiarch_binary(
    name = "converter",
    target_arch = "generic",
    deps = [
        ":biaffine_digraph_component",
        ":bulk_dynamic_component",
        ":bulk_feed_forward_network",
        ":bulk_lstm_network",
        ":clear_dropout_component_transformer",
        ":converter_main",
        ":dynamic_component",
        ":feed_forward_network",
        ":identity_sequence_linker",
        ":lstm_network",
        ":recurrent_sequence_linkers",
        ":reversed_sequence_linker",
        ":select_best_component_transformer",
        ":sequence_backend",
        ":sequence_bulk_dynamic_component",
        ":sequence_component_transformer",
        ":stateless_component_transformer",
        ":syntaxnet_character_sequence_extractor",
        ":syntaxnet_character_sequence_linkers",
        ":syntaxnet_head_selection_component",
        ":syntaxnet_mst_solver_component",
        ":syntaxnet_tag_sequence_predictor",
        ":syntaxnet_word_sequence_extractor",
        "//dragnn/components/stateless:stateless_component",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/mst:mst_ops_cc",
        "//dragnn/runtime/myelin:myelin_dynamic_component",
        "//dragnn/runtime/myelin:sequence_myelin_dynamic_component",
        "//dragnn/runtime/xla:xla_dynamic_component",
        "//syntaxnet:parser_transitions",
    ],
)

# End-to-end shell test that runs the converter against checked-in test data.
sh_test(
    name = "converter_test",
    size = "medium",
    srcs = ["converter_test.sh"],
    data = [":converter"] + glob([
        "testdata/converter_output/**",
        "testdata/rnn_tagger/**",
    ]),
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Definitions of activation functions for neural networks.
#ifndef DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#define DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#include "dragnn/runtime/math/arithmetic.h"
#include "dragnn/runtime/math/types.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Possible types of activation functions.
//
// TODO(googleuser): If many activation functions are added, or if functions start
// using configuration parameters (e.g., leakiness of a leaky ReLU), then switch
// to a registered class.
enum class ActivationFunction {
  kIdentity,  // pass-through (no-op), useful for classification logits
  kRelu,      // ReLU; i.e., max(0,x)
};
// Applies the |activation_function| elementwise to the |values|, in place.
template <class T>
void ApplyActivationFunction(ActivationFunction activation_function,
                             MutableVector<T> values);

// Implementation details below.

template <class T>
void ApplyActivationFunction(ActivationFunction activation_function,
                             MutableVector<T> values) {
  // kIdentity requires no work, so only kRelu modifies |values|.
  if (activation_function == ActivationFunction::kRelu) {
    // ReLU replaces each element x with max(0, x); T() is the zero of T.
    MaxElements(T(), values);
  }
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/activation_functions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/helpers.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that kIdentity leaves every element untouched.
TEST(ActivationFunctionsTest, ApplyIdentity) {
  UniqueVector<float> values({1.25f, -1.5f, 0.0f, 0.0625f, -0.03125});
  ApplyActivationFunction(ActivationFunction::kIdentity, *values);

  // Positive, negative, and zero inputs must all pass through unchanged.
  const float kExpected[] = {1.25f, -1.5f, 0.0f, 0.0625f, -0.03125f};
  for (int i = 0; i < 5; ++i) EXPECT_EQ((*values)[i], kExpected[i]);
}
// Tests that kRelu zeroes out negative elements and preserves the rest.
TEST(ActivationFunctionsTest, ApplyRelu) {
  UniqueVector<float> values({1.25f, -1.5f, 0.0f, 0.0625f, -0.03125});
  ApplyActivationFunction(ActivationFunction::kRelu, *values);

  // Negative inputs clip to zero; zero stays at the boundary; positive
  // inputs are unchanged.
  const float kExpected[] = {1.25f, 0.0f, 0.0f, 0.0625f, 0.0f};
  for (int i = 0; i < 5; ++i) EXPECT_EQ((*values)[i], kExpected[i]);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with aligned memory blocks. The DRAGNN runtime requires
// aligned memory for use in vectorized math. Do not rely on any particular
// value of the alignment requirement, because it will vary over time and in
// different build configurations.
#ifndef DRAGNN_RUNTIME_ALIGNMENT_H_
#define DRAGNN_RUNTIME_ALIGNMENT_H_
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <type_traits>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
// This is a type that has some private methods (so non-POD), but is known to be
// trivially-destructible.  Ergo we add some special handling so
// IsAlignable<bfloat16> returns true.
namespace tensorflow {
struct bfloat16;
}
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Returns true if |T| can be used in an aligned memory block.
template <class T>
constexpr bool IsAlignable();

// Returns OK iff the |pointer| satisfies the alignment requirement.
tensorflow::Status OkIfAligned(const void *pointer);

// Returns the next alignment boundary at or after the |byte_offset|.
size_t PadToAlignment(size_t byte_offset);

// As above, but for pointers.  Requires IsAlignable<T>().
template <class T>
T *PadToAlignment(T *pointer);

// Returns the number of bytes required to store a sequence of |num_arrays|
// aligned arrays of |array_size| bytes, including alignment padding.  See
// (Mutable)AlignedArea below.
size_t ComputeAlignedAreaSize(size_t num_arrays, size_t array_size);

// Returns the number of bytes required to store a sequence of byte arrays of
// the given |sizes|, including alignment padding after each array.
size_t ComputeTotalBytesWithAlignmentPadding(const std::vector<size_t> &sizes);

// Forward-declared for friendship below.
class Operands;
class UniqueAlignedArray;
enum class BlockedMatrixFormat;
namespace internal {
// A non-owning view of an aligned byte array.  Templated so const and mutable
// versions can share implementation.  Do not use this class directly, instead
// use (Mutable)AlignedView below.
template <class Byte>
class AlignedViewImpl {
 public:
  static_assert(sizeof(Byte) == 1, "Byte must be byte-sized");

  // Creates an empty view.
  AlignedViewImpl() = default;

  // Points this at the same bytes as |that|, possibly reinterpreting type
  // (e.g., const char <=> char).
  template <class OtherByte>
  explicit AlignedViewImpl(AlignedViewImpl<OtherByte> that);
  template <class OtherByte>
  AlignedViewImpl &operator=(AlignedViewImpl<OtherByte> that);

  // Points this at [|data|,|data|+|size|).  On error, returns non-OK and
  // modifies nothing.
  tensorflow::Status Reset(Byte *data, size_t size);

  // Splits this into a list of |views| of the |sizes|, possibly reinterpreting
  // type.  The |views| need not completely cover all bytes of this.  Requires
  // that this spans ComputeTotalBytesWithAlignmentPadding(|sizes|) bytes.  On
  // error, returns non-OK and modifies nothing.
  template <class OtherByte>
  tensorflow::Status Split(
      const std::vector<size_t> &sizes,
      std::vector<AlignedViewImpl<OtherByte>> *views) const;

  // Accessors.
  Byte *data() const { return data_; }
  size_t size() const { return size_; }
  bool empty() const { return size() == 0; }

 private:
  // Friends that may use the unchecked constructor below.
  template <class OtherByte>
  friend class AlignedViewImpl;
  template <class OtherByte>
  friend class AlignedAreaImpl;
  friend Operands;
  friend UniqueAlignedArray;

  // Directly creates an aligned view, bypassing alignment checks.
  AlignedViewImpl(Byte *data, size_t size);

  // Pointer to the start of the view.
  Byte *data_ = nullptr;

  // Number of bytes in the view.
  size_t size_ = 0;
};
// A non-owning view of an aligned, 2-dimensional byte array.  Templated so
// const and mutable versions can share implementation.  Do not use this class
// directly, instead use (Mutable)AlignedArea below.
template <class Byte>
class AlignedAreaImpl {
 public:
  static_assert(sizeof(Byte) == 1, "Byte must be byte-sized");

  // Creates an empty area.
  AlignedAreaImpl() = default;

  // Points this at the same bytes as |that|, possibly reinterpreting type.
  template <class OtherByte>
  explicit AlignedAreaImpl(AlignedAreaImpl<OtherByte> that);
  template <class OtherByte>
  AlignedAreaImpl &operator=(AlignedAreaImpl<OtherByte> that);

  // Resets this to a sequence of |num_views| aligned sub-views of the |view|,
  // each |view_size| bytes wide.  The first sub-view covers [0,|view_size|) of
  // |view|, and each subsequent sub-view starts at the next alignment boundary.
  // Requires that |view| spans ComputeAlignedAreaSize(|num_views|,|view_size|)
  // bytes or more.  On error, returns non-OK and modifies nothing.
  template <class OtherByte>
  tensorflow::Status Reset(AlignedViewImpl<OtherByte> view, size_t num_views,
                           size_t view_size);

  // Accessors.
  AlignedViewImpl<Byte> view(size_t index) const;
  Byte *data() const { return data_; }
  size_t num_views() const { return num_views_; }
  size_t view_size() const { return view_size_; }
  size_t view_stride() const { return view_stride_; }
  bool empty() const { return num_views() == 0; }

 private:
  template <class OtherByte>
  friend class AlignedAreaImpl;
  friend Operands;

  // Directly creates an aligned area, bypassing alignment checks.
  AlignedAreaImpl(Byte *data, size_t num_views, size_t view_size,
                  size_t view_stride);

  // Pointer to the start of the first view.
  Byte *data_ = nullptr;

  // Number of views in the area.
  size_t num_views_ = 0;

  // Size of each view in bytes, excluding alignment padding.
  size_t view_size_ = 0;

  // Number of bytes between the starts of consecutive views.  NB: This is not
  // necessarily equal to PadToAlignment(|view_size_|).
  size_t view_stride_ = 0;
};
} // namespace internal
// Public aliases; use these.  The plain variants view immutable bytes, while
// the Mutable variants view writable bytes.
using AlignedView = internal::AlignedViewImpl<const char>;
using AlignedArea = internal::AlignedAreaImpl<const char>;
using MutableAlignedView = internal::AlignedViewImpl<char>;
using MutableAlignedArea = internal::AlignedAreaImpl<char>;
// A uniquely-owned aligned byte array.
class UniqueAlignedArray {
 public:
  // Creates an empty byte array.
  UniqueAlignedArray() = default;

  // Reallocates this to |new_size| bytes, and discards the current byte array.
  // Contents are uninitialized.
  void Reset(size_t new_size);

  // Like Reset(), but only reallocates if |new_size| is more than the current
  // capacity.  NB: Does not preserve current content when reallocation occurs;
  // use Resize() if that is desired.
  void Reserve(size_t new_size);

  // Resizes this to contain |new_size| bytes, preserving current content.  If
  // |new_size| exceeds the current size, the added bytes are uninitialized.  If
  // |new_size| exceeds the current capacity, reallocates, and copies current
  // content.  Returns true if reallocation occurred.
  bool Resize(size_t new_size);

  // Returns the aligned byte array.
  MutableAlignedView view() const { return view_; }

 private:
  // Underlying byte array, which is padded for alignment.
  std::unique_ptr<char[]> padded_array_;

  // Size of the aligned portion of |padded_array_|.
  size_t capacity_ = 0;

  // Active range of the aligned portion of |padded_array_|.
  MutableAlignedView view_;
};
// Implementation details below.

namespace internal {

// Required alignment for memory blocks.  Must be a power of two, since
// PadToAlignment() below rounds with mask arithmetic.  Only the runtime
// framework should use this; otherwise, DO NOT access or otherwise depend on
// this value.
enum : size_t { kAlignmentBytes = 32 };

}  // namespace internal
template <class T>
constexpr bool IsAlignable() {
  // T must be safe to store without running constructors or destructors (POD,
  // or the specially-allowed tensorflow::bfloat16), and both its size and its
  // alignment requirement must either evenly tile an alignment window or be
  // evenly tiled by one.
  return (std::is_pod<T>::value ||
          std::is_same<T, tensorflow::bfloat16>::value) &&
         (sizeof(T) % internal::kAlignmentBytes == 0 ||
          internal::kAlignmentBytes % sizeof(T) == 0) &&
         (alignof(T) % internal::kAlignmentBytes == 0 ||
          internal::kAlignmentBytes % alignof(T) == 0);
}
// Returns OK iff |pointer| satisfies the runtime alignment requirement.
inline tensorflow::Status OkIfAligned(const void *pointer) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
  if (address % internal::kAlignmentBytes == 0) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Pointer fails alignment requirement: ", address, " vs required ",
      internal::kAlignmentBytes);
}
// Returns |byte_offset| rounded up to the next alignment boundary.
inline size_t PadToAlignment(size_t byte_offset) {
  // Since the alignment is a power of two, rounding up is adding one less than
  // the alignment and then masking off the low-order bits, which rounds back
  // down to the nearest alignment boundary.
  const size_t mask = internal::kAlignmentBytes - 1;
  return (byte_offset + mask) & ~mask;
}
// Returns |pointer| advanced to the next alignment boundary.
template <class T>
T *PadToAlignment(T *pointer) {
  static_assert(IsAlignable<T>(), "T is not alignable");
  // Same add-then-mask rounding as the offset overload, applied to the
  // integral value of the pointer.
  const uintptr_t mask = internal::kAlignmentBytes - 1;
  const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
  return reinterpret_cast<T *>((address + mask) & ~mask);
}
// Returns the number of bytes required to hold |num_arrays| sub-arrays of
// |array_size| bytes each, where every sub-array starts on an alignment
// boundary.
inline size_t ComputeAlignedAreaSize(size_t num_arrays, size_t array_size) {
  const size_t padded_array_size = PadToAlignment(array_size);
  return padded_array_size * num_arrays;
}
// Returns the number of bytes required to hold sub-arrays of the given
// |sizes|, where every sub-array starts on an alignment boundary.
inline size_t ComputeTotalBytesWithAlignmentPadding(
    const std::vector<size_t> &sizes) {
  size_t padded_total = 0;
  for (const size_t unpadded_size : sizes) {
    padded_total += PadToAlignment(unpadded_size);
  }
  return padded_total;
}
namespace internal {
// Reinterpreting copy constructor; e.g., builds a const view over the same
// bytes as a mutable view.  Shallow copy: no bytes are touched.
template <class Byte>
template <class OtherByte>
AlignedViewImpl<Byte>::AlignedViewImpl(AlignedViewImpl<OtherByte> that)
    : data_(reinterpret_cast<Byte *>(that.data())), size_(that.size()) {}
// Reinterpreting assignment; shallow-copies the pointer and size from a view
// with a possibly different constness.
template <class Byte>
template <class OtherByte>
AlignedViewImpl<Byte> &AlignedViewImpl<Byte>::operator=(
    AlignedViewImpl<OtherByte> that) {
  data_ = reinterpret_cast<Byte *>(that.data());
  size_ = that.size();
  return *this;
}
// Points this view at [data, data + size).  Fails if |data| is misaligned, in
// which case this view is left unmodified.
template <class Byte>
tensorflow::Status AlignedViewImpl<Byte>::Reset(Byte *data, size_t size) {
  TF_RETURN_IF_ERROR(OkIfAligned(data));
  // Success; make modifications.
  data_ = data;
  size_ = size;
  return tensorflow::Status::OK();
}
// Splits this view into consecutive sub-views with the given |sizes|, each
// starting at an aligned offset.  Fails if this view is too small to hold all
// sub-views plus alignment padding, in which case |views| is unmodified.
// Trailing bytes beyond the padded total are simply left out of all sub-views.
template <class Byte>
template <class OtherByte>
tensorflow::Status AlignedViewImpl<Byte>::Split(
    const std::vector<size_t> &sizes,
    std::vector<AlignedViewImpl<OtherByte>> *views) const {
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  if (size() < total_bytes) {
    return tensorflow::errors::InvalidArgument(
        "View is too small to be split into sizes [",
        tensorflow::str_util::Join(sizes, ", "), "]: need ", total_bytes,
        " bytes but have ", size(), " bytes");
  }
  // Success; make modifications.
  views->clear();
  views->reserve(sizes.size());
  Byte *base = data();
  for (const size_t size : sizes) {
    views->push_back(AlignedViewImpl<OtherByte>(base, size));
    // Skip over any padding so the next sub-view starts aligned.
    base = PadToAlignment(base + size);
  }
  DCHECK_EQ(base - data(), total_bytes);
  return tensorflow::Status::OK();
}
// Directly adopts [data, data + size).  Alignment of |data| is only checked in
// debug builds; callers are trusted to pass an aligned pointer.
template <class Byte>
AlignedViewImpl<Byte>::AlignedViewImpl(Byte *data, size_t size)
    : data_(data), size_(size) {
  TF_DCHECK_OK(OkIfAligned(data_));
}
// Reinterpreting copy constructor; e.g., builds a const area over the same
// bytes as a mutable area.  Shallow copy of the geometry; no bytes are
// touched.  Reads |that|'s private fields directly, across instantiations.
template <class Byte>
template <class OtherByte>
AlignedAreaImpl<Byte>::AlignedAreaImpl(AlignedAreaImpl<OtherByte> that)
    : data_(reinterpret_cast<Byte *>(that.data_)),
      num_views_(that.num_views()),
      view_size_(that.view_size()),
      view_stride_(that.view_stride_) {}
// Reinterpreting assignment; shallow-copies the pointer and geometry from an
// area with a possibly different constness.
template <class Byte>
template <class OtherByte>
AlignedAreaImpl<Byte> &AlignedAreaImpl<Byte>::operator=(
    AlignedAreaImpl<OtherByte> that) {
  data_ = reinterpret_cast<Byte *>(that.data_);
  num_views_ = that.num_views();
  view_size_ = that.view_size();
  view_stride_ = that.view_stride_;
  return *this;
}
// Carves |view| into an area of |num_views| sub-views of |view_size| bytes
// each, where every sub-view starts on an alignment boundary.  Fails if |view|
// is too small, in which case this area is left unmodified.
template <class Byte>
template <class OtherByte>
tensorflow::Status AlignedAreaImpl<Byte>::Reset(AlignedViewImpl<OtherByte> view,
                                                size_t num_views,
                                                size_t view_size) {
  const size_t total_bytes = ComputeAlignedAreaSize(num_views, view_size);
  if (view.size() < total_bytes) {
    return tensorflow::errors::InvalidArgument(
        "View is too small for area of ", num_views, " views of ", view_size,
        " bytes: need ", total_bytes, " bytes but got ", view.size(), " bytes");
  }
  // Success; make modifications.
  data_ = reinterpret_cast<Byte *>(view.data());
  num_views_ = num_views;
  view_size_ = view_size;
  // Sub-views are laid out at padded intervals so each starts aligned.
  view_stride_ = PadToAlignment(view_size_);
  return tensorflow::Status::OK();
}
// Returns the |index|th sub-view of this area.  Bounds are checked in debug
// builds only.
template <class Byte>
AlignedViewImpl<Byte> AlignedAreaImpl<Byte>::view(size_t index) const {
  DCHECK_LT(index, num_views());
  return AlignedViewImpl<Byte>(data_ + view_stride_ * index, view_size_);
}
// Directly adopts the given geometry.  Debug-checks that the base pointer is
// aligned and that the stride preserves alignment between consecutive
// sub-views; the null-pointer-plus-stride expression turns |view_stride_| into
// a pointer so that OkIfAligned() can check it.
template <class Byte>
AlignedAreaImpl<Byte>::AlignedAreaImpl(Byte *data, size_t num_views,
                                       size_t view_size, size_t view_stride)
    : data_(data),
      num_views_(num_views),
      view_size_(view_size),
      view_stride_(view_stride) {
  TF_DCHECK_OK(OkIfAligned(data_));
  TF_DCHECK_OK(OkIfAligned(static_cast<const char *>(nullptr) + view_stride_));
}
} // namespace internal
// Reallocates to |new_size| bytes; contents become uninitialized (see header).
inline void UniqueAlignedArray::Reset(size_t new_size) {
  // Pad the |new_size| to the next alignment boundary, so the final bytes of
  // the array are still in a full alignment window.  E.g., if we resize to 48
  // bytes with 32-byte alignment, then we allocate 64 bytes so the final 16
  // bytes are still part of a full 32-byte alignment window.
  const size_t aligned_size = PadToAlignment(new_size);
  // To obtain an aligned address, allocate a sufficiently-padded byte array and
  // find an aligned address near the start of the block.  Any window of
  // kAlignmentBytes consecutive addresses contains exactly one aligned
  // address, so kAlignmentBytes - 1 extra bytes always suffice.
  //
  // TODO(googleuser): Alternatively, we could use library functions such as
  // memalign(), posix_memalign(), or aligned_alloc(), but those may not be
  // present on all platforms.  Consider adding some #ifs to allow use of those
  // library functions when available.
  padded_array_.reset(new char[aligned_size + internal::kAlignmentBytes - 1]);
  capacity_ = aligned_size;
  view_.size_ = new_size;
  view_.data_ = PadToAlignment(padded_array_.get());
  TF_DCHECK_OK(OkIfAligned(view_.data_));
}
// Grows capacity if needed; does NOT preserve contents across reallocation
// (see header comment; use Resize() for that).
inline void UniqueAlignedArray::Reserve(size_t new_size) {
  if (new_size <= capacity_) {
    // Already have room; just adjust the active range.
    view_.size_ = new_size;
    return;
  }
  Reset(new_size);
}
// Resizes to |new_size| bytes, preserving current content (see header).
// Returns true iff a reallocation occurred.
inline bool UniqueAlignedArray::Resize(size_t new_size) {
  // Avoid reallocation, if possible.
  if (new_size <= capacity_) {
    view_.size_ = new_size;
    return false;
  }
  // Reallocate and copy.  Extend the life of the old array until it is copied.
  //
  // Note: C realloc() can extend a byte array in place (i.e., without copying).
  // Unfortunately, there is no aligned version of realloc().  Moreover, adding
  // alignment padding could cause double-copying: first, when realloc() copies
  // the data to the new buffer, and second, if the amount of padding required
  // at the new address is not the same as before.
  const std::unique_ptr<char[]> old_array = std::move(padded_array_);
  const MutableAlignedView old_view = view_;
  // Allocate double the requested size so repeated growth is amortized.
  Reset(2 * new_size);
  // Skip the copy when there is nothing to preserve.  In particular, the first
  // Resize() of a default-constructed array has a null |old_view|, and calling
  // memcpy() with a null pointer is undefined behavior even when the size is 0.
  if (!old_view.empty()) {
    memcpy(view_.data(), old_view.data(), old_view.size());
  }
  view_.size_ = new_size;
  return true;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ALIGNMENT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/alignment.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// The tests below presumably rely on at least 4-byte alignment (e.g., the
// float-based PadToAlignment checks) -- confirm if the constant changes.
static_assert(internal::kAlignmentBytes >= 4, "alignment too small");
// Expects that two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}
// Tests that standard scalar types are alignable: each divides evenly into a
// 32-byte alignment window.
TEST(IsAlignableTest, Alignable) {
  EXPECT_TRUE(IsAlignable<char>());
  EXPECT_TRUE(IsAlignable<float>());
  EXPECT_TRUE(IsAlignable<double>());
}
// Tests that objects of odd sizes are not alignable: they neither divide nor
// are divisible by the alignment window.  (7919 is prime.)
TEST(IsAlignableTest, NotAlignable) {
  EXPECT_FALSE(IsAlignable<char[3]>());
  EXPECT_FALSE(IsAlignable<char[7]>());
  EXPECT_FALSE(IsAlignable<char[7919]>());
}
// Tests that OkIfAligned() returns OK on aligned pointers; null and exact
// multiples of the alignment all qualify.
TEST(OkIfAlignedTest, Aligned) {
  const char *ptr = nullptr;
  TF_EXPECT_OK(OkIfAligned(ptr));
  ptr += internal::kAlignmentBytes;
  TF_EXPECT_OK(OkIfAligned(ptr));
  ptr += 123 * internal::kAlignmentBytes;
  TF_EXPECT_OK(OkIfAligned(ptr));
}
// Tests that OkIfAligned() returns non-OK on misaligned pointers.
TEST(OkIfAlignedTest, NotAligned) {
  const char *ptr = nullptr;
  // Offsets 1 and 23 are not multiples of the (>= 4) alignment.
  EXPECT_THAT(OkIfAligned(ptr + 1),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
  EXPECT_THAT(OkIfAligned(ptr + 23),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
}
// Tests that any window of |internal::kAlignmentBytes| bytes contains exactly
// one aligned address.
TEST(OkIfAlignedTest, OnePerAlignmentWindow) {
  // Note that |bytes| does not necessarily start at an aligned address.  Even
  // so, it is still guaranteed to contain exactly one aligned address, in the
  // same sense that any sequence of 10 consecutive integers contains exactly
  // one whose decimal representation ends in '0'.  This property is exploited
  // in UniqueAlignedArray::Reset().
  const string bytes(internal::kAlignmentBytes, ' ');
  int num_ok = 0;
  for (int i = 0; i < bytes.size(); ++i) {
    if (OkIfAligned(bytes.data() + i).ok()) ++num_ok;
  }
  EXPECT_EQ(num_ok, 1);
}
// Tests that PadToAlignment() produces an aligned byte offset: already-aligned
// offsets are unchanged, and all others round up to the next boundary.
TEST(PadToAlignmentTest, Offset) {
  EXPECT_EQ(PadToAlignment(0), 0);
  EXPECT_EQ(PadToAlignment(1), internal::kAlignmentBytes);
  EXPECT_EQ(PadToAlignment(internal::kAlignmentBytes + 1),
            2 * internal::kAlignmentBytes);
  EXPECT_EQ(PadToAlignment(99 * internal::kAlignmentBytes + 3),
            100 * internal::kAlignmentBytes);
}
// Tests that PadToAlignment() produces an aligned pointer, for both char and
// float element types.
TEST(PadToAlignmentTest, Pointer) {
  const string bytes = "hello";
  TF_EXPECT_OK(OkIfAligned(PadToAlignment(bytes.data())));
  const std::vector<float> reals(10);
  TF_EXPECT_OK(OkIfAligned(PadToAlignment(reals.data())));
}
// Tests that ComputeAlignedAreaSize() calculates the correct size: each of the
// |num_arrays| sub-arrays is padded up to a whole number of alignment windows.
TEST(ComputeAlignedAreaSizeTest, Basic) {
  EXPECT_EQ(ComputeAlignedAreaSize(0, 0), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(0, 1), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(1, 0), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(1, 1), internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(1, internal::kAlignmentBytes),
            internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(3, internal::kAlignmentBytes + 1),
            6 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(11, internal::kAlignmentBytes - 1),
            11 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(7, internal::kAlignmentBytes),
            7 * internal::kAlignmentBytes);
}
// Tests that ComputeTotalBytesWithAlignmentPadding() calculates the correct
// total size: the sum of each size padded to a whole alignment window.
TEST(ComputeTotalBytesWithAlignmentPaddingTest, DifferentSizes) {
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({0}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({0, 0, 0}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1}),
            internal::kAlignmentBytes);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1, 1, 1}),
            3 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding(
                {1, internal::kAlignmentBytes, internal::kAlignmentBytes + 1}),
            4 * internal::kAlignmentBytes);
  // Sizes 1..kAlignmentBytes each pad to exactly one window.
  std::vector<size_t> sizes;
  for (size_t i = 1; i <= internal::kAlignmentBytes; ++i) sizes.push_back(i);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding(sizes),
            internal::kAlignmentBytes * internal::kAlignmentBytes);
}
// Tests that ComputeTotalBytesWithAlignmentPadding() is equivalent to
// ComputeAlignedAreaSize() when all sizes are equal.
TEST(ComputeTotalBytesWithAlignmentPaddingTest, AllSameSize) {
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1, 1, 1, 1}),
            ComputeAlignedAreaSize(4, 1));
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({7, 7, 7, 7, 7, 7}),
            ComputeAlignedAreaSize(6, 7));
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({77, 77, 77}),
            ComputeAlignedAreaSize(3, 77));
}
// Tests that UniqueAlignedArray is empty by default; its view has size 0.
TEST(UniqueAlignedArrayTest, EmptyByDefault) {
  UniqueAlignedArray array;
  EXPECT_EQ(array.view().size(), 0);
  EXPECT_TRUE(array.view().empty());
}
// Tests that UniqueAlignedArray::Reset() always reallocates.
TEST(UniqueAlignedArrayTest, Reset) {
  UniqueAlignedArray array;
  // Reset to non-empty.
  array.Reset(10);
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 10);
  // Calling view() again should return the same byte array.
  const MutableAlignedView view2 = array.view();
  ExpectSameAddress(view2.data(), view1.data());
  EXPECT_EQ(view2.size(), view1.size());
  // Reset to a different size.  The new allocation need not be at the same
  // address, so only alignment and size are checked.
  array.Reset(33);
  const MutableAlignedView view3 = array.view();
  TF_EXPECT_OK(OkIfAligned(view3.data()));
  EXPECT_EQ(view3.size(), 33);
}
// Tests that UniqueAlignedArray::Reserve() reallocates only when growing past
// the current capacity.
TEST(UniqueAlignedArrayTest, Reserve) {
  UniqueAlignedArray array;
  // Reserve to non-empty.
  array.Reserve(20);
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 20);
  // Shrink to a smaller size; should not reallocate.
  array.Reserve(7);
  const MutableAlignedView view2 = array.view();
  ExpectSameAddress(view2.data(), view1.data());
  EXPECT_EQ(view2.size(), 7);
  // Grow but still remain within capacity; should not reallocate.
  array.Reserve(14);
  const MutableAlignedView view3 = array.view();
  ExpectSameAddress(view3.data(), view1.data());
  EXPECT_EQ(view3.size(), 14);
}
// Tests that UniqueAlignedArray::Resize() reallocates when growing and
// preserves existing contents.  Resize() returns true iff it reallocated.
TEST(UniqueAlignedArrayTest, Resize) {
  UniqueAlignedArray array;
  // Resize to non-empty.
  EXPECT_TRUE(array.Resize(10));
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 10);
  // Write some stuff.
  for (int i = 0; i < 10; ++i) view1.data()[i] = '1';
  // Resize to a larger size.
  EXPECT_TRUE(array.Resize(33));
  const MutableAlignedView view2 = array.view();
  TF_EXPECT_OK(OkIfAligned(view2.data()));
  EXPECT_EQ(view2.size(), 33);
  // Check that content was preserved.
  for (int i = 0; i < 10; ++i) EXPECT_EQ(view2.data()[i], '1');
  // Append more stuff.
  for (int i = 10; i < 33; ++i) view2.data()[i] = '2';
  // Resize to a smaller size.
  EXPECT_FALSE(array.Resize(15));
  const MutableAlignedView view3 = array.view();
  TF_EXPECT_OK(OkIfAligned(view3.data()));
  ExpectSameAddress(view3.data(), view2.data());
  EXPECT_EQ(view3.size(), 15);
  // Check that content was preserved.
  for (int i = 0; i < 10; ++i) EXPECT_EQ(view3.data()[i], '1');
  for (int i = 10; i < 15; ++i) EXPECT_EQ(view3.data()[i], '2');
  // Overwrite with new stuff.
  for (int i = 0; i < 15; ++i) view3.data()[i] = '3';
  // Resize to a larger size, but still below capacity.
  EXPECT_FALSE(array.Resize(20));
  const MutableAlignedView view4 = array.view();
  TF_EXPECT_OK(OkIfAligned(view4.data()));
  ExpectSameAddress(view4.data(), view2.data());
  EXPECT_EQ(view4.size(), 20);
  // Check that content was preserved.
  for (int i = 0; i < 15; ++i) EXPECT_EQ(view4.data()[i], '3');
}
// Tests that (Mutable)AlignedView is empty by default.
TEST(AlignedViewTest, EmptyByDefault) {
  AlignedView view1;
  EXPECT_EQ(view1.size(), 0);
  EXPECT_TRUE(view1.empty());
  // Same for the mutable variant.
  MutableAlignedView view2;
  EXPECT_EQ(view2.size(), 0);
  EXPECT_TRUE(view2.empty());
}
// Tests that (Mutable)AlignedView::Reset() works on aligned pointers.  The
// views never dereference, so a synthetic aligned address suffices.
TEST(AlignedViewTest, ResetValid) {
  char *pointer = nullptr;
  pointer += 3 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_EXPECT_OK(view1.Reset(pointer, 100));
  ExpectSameAddress(view1.data(), pointer);
  EXPECT_EQ(view1.size(), 100);
  EXPECT_FALSE(view1.empty());
  MutableAlignedView view2;
  TF_EXPECT_OK(view2.Reset(pointer, 100));
  ExpectSameAddress(view2.data(), pointer);
  EXPECT_EQ(view2.size(), 100);
  EXPECT_FALSE(view2.empty());
}
// Tests that (Mutable)AlignedView::Reset() fails on misaligned pointers.
TEST(AlignedViewTest, ResetInvalid) {
  char *pointer = nullptr;
  ++pointer;  // not aligned
  AlignedView view1;
  EXPECT_THAT(view1.Reset(pointer, 10),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
  MutableAlignedView view2;
  EXPECT_THAT(view2.Reset(pointer, 10),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
}
// Tests that (Mutable)AlignedView::Reset() can empty the view by passing a
// zero size.
TEST(AlignedViewTest, ResetEmpty) {
  char *pointer = nullptr;
  pointer += 11 * internal::kAlignmentBytes;
  // First point to a non-empty byte array.
  AlignedView view1;
  TF_EXPECT_OK(view1.Reset(pointer, 100));
  ExpectSameAddress(view1.data(), pointer);
  EXPECT_EQ(view1.size(), 100);
  EXPECT_FALSE(view1.empty());
  // Then reset to empty.
  TF_EXPECT_OK(view1.Reset(pointer, 0));
  EXPECT_EQ(view1.size(), 0);
  EXPECT_TRUE(view1.empty());
  // First point to a non-empty byte array.
  MutableAlignedView view2;
  TF_EXPECT_OK(view2.Reset(pointer, 100));
  ExpectSameAddress(view2.data(), pointer);
  EXPECT_EQ(view2.size(), 100);
  EXPECT_FALSE(view2.empty());
  // Then reset to empty.
  TF_EXPECT_OK(view2.Reset(pointer, 0));
  EXPECT_EQ(view2.size(), 0);
  EXPECT_TRUE(view2.empty());
}
// Tests that (Mutable)AlignedView supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
// (The reverse, const to mutable, should not compile and is not tested here.)
TEST(AlignedViewTest, CopyAndAssign) {
  char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  const char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  MutableAlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, 100));
  AlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, 200));
  // Mutable-to-mutable copy, clear, and assign.
  MutableAlignedView view3(view1);
  ExpectSameAddress(view3.data(), pointer1);
  EXPECT_EQ(view3.size(), 100);
  EXPECT_FALSE(view3.empty());
  view3 = MutableAlignedView();
  EXPECT_EQ(view3.size(), 0);
  EXPECT_TRUE(view3.empty());
  view3 = view1;
  ExpectSameAddress(view3.data(), pointer1);
  EXPECT_EQ(view3.size(), 100);
  EXPECT_FALSE(view3.empty());
  AlignedView view4(view1);  // reinterprets type
  ExpectSameAddress(view4.data(), pointer1);
  EXPECT_EQ(view4.size(), 100);
  EXPECT_FALSE(view4.empty());
  view4 = AlignedView();
  EXPECT_EQ(view4.size(), 0);
  EXPECT_TRUE(view4.empty());
  view4 = view2;
  ExpectSameAddress(view4.data(), pointer2);
  EXPECT_EQ(view4.size(), 200);
  EXPECT_FALSE(view4.empty());
  view4 = view1;  // reinterprets type
  ExpectSameAddress(view4.data(), pointer1);
  EXPECT_EQ(view4.size(), 100);
  EXPECT_FALSE(view4.empty());
  view4 = MutableAlignedView();  // reinterprets type
  EXPECT_EQ(view4.size(), 0);
  EXPECT_TRUE(view4.empty());
}
// Tests that AlignedView can split itself into sub-views with specified sizes.
// Each sub-view should start at the next aligned offset after its predecessor.
TEST(AlignedViewTest, SplitConst) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  AlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes));
  std::vector<AlignedView> views(100);  // will be resized
  TF_ASSERT_OK(view.Split(sizes, &views));
  ASSERT_EQ(views.size(), 5);
  // Walk the expected aligned offsets alongside the returned sub-views.
  const char *base = view.data();
  ExpectSameAddress(views[0].data(), base);
  EXPECT_EQ(views[0].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[1].data(), base);
  EXPECT_EQ(views[1].size(), internal::kAlignmentBytes);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[2].data(), base);
  EXPECT_EQ(views[2].size(), internal::kAlignmentBytes + 1);
  base += 2 * internal::kAlignmentBytes;
  ExpectSameAddress(views[3].data(), base);
  EXPECT_EQ(views[3].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[4].data(), base);
  EXPECT_EQ(views[4].size(), 123);
}
// Tests that MutableAlignedView can split itself into sub-views with specified
// sizes, and reinterprets from char* to const char*.
TEST(AlignedViewTest, SplitMutable) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  // Also add some padding to check that we can split part of the view.
  MutableAlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes + 10));
  std::vector<AlignedView> const_views(99);          // will be resized
  std::vector<MutableAlignedView> mutable_views(2);  // will be resized
  TF_ASSERT_OK(view.Split(sizes, &const_views));
  TF_ASSERT_OK(view.Split(sizes, &mutable_views));
  ASSERT_EQ(const_views.size(), 5);
  ASSERT_EQ(mutable_views.size(), 5);
  // Both splits should produce identical addresses and sizes.
  const char *base = view.data();
  ExpectSameAddress(const_views[0].data(), base);
  ExpectSameAddress(mutable_views[0].data(), base);
  EXPECT_EQ(const_views[0].size(), 1);
  EXPECT_EQ(mutable_views[0].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[1].data(), base);
  ExpectSameAddress(mutable_views[1].data(), base);
  EXPECT_EQ(const_views[1].size(), internal::kAlignmentBytes);
  EXPECT_EQ(mutable_views[1].size(), internal::kAlignmentBytes);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[2].data(), base);
  ExpectSameAddress(mutable_views[2].data(), base);
  EXPECT_EQ(const_views[2].size(), internal::kAlignmentBytes + 1);
  EXPECT_EQ(mutable_views[2].size(), internal::kAlignmentBytes + 1);
  base += 2 * internal::kAlignmentBytes;
  ExpectSameAddress(const_views[3].data(), base);
  ExpectSameAddress(mutable_views[3].data(), base);
  EXPECT_EQ(const_views[3].size(), 1);
  EXPECT_EQ(mutable_views[3].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[4].data(), base);
  ExpectSameAddress(mutable_views[4].data(), base);
  EXPECT_EQ(const_views[4].size(), 123);
  EXPECT_EQ(mutable_views[4].size(), 123);
}
// Tests that Split() fails when the view cannot hold the requested sub-views
// plus alignment padding.
TEST(AlignedViewTest, SplitTooSmall) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  // Make the view just a bit too small.
  MutableAlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes - 1));
  std::vector<MutableAlignedView> views;
  EXPECT_THAT(view.Split(sizes, &views),
              test::IsErrorWithSubstr("View is too small to be split"));
}
// Tests that (Mutable)AlignedArea is empty by default: no views, zero size.
TEST(AlignedAreaTest, EmptyByDefault) {
  AlignedArea area1;
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 0);
  EXPECT_TRUE(area1.empty());
  MutableAlignedArea area2;
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 0);
  EXPECT_TRUE(area2.empty());
}
// Tests that (Mutable)AlignedArea::Reset() can initialize to a single view,
// from either a const or a mutable source view, and can be Reset() repeatedly.
TEST(AlignedAreaTest, ResetSingleton) {
  const char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, internal::kAlignmentBytes + 1));
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 1, 1));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer1);
  EXPECT_EQ(area1.view(0).size(), 1);
  // A const area can also be Reset() from a mutable view.
  TF_ASSERT_OK(area1.Reset(view2, 1, 2));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 2);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer2);
  EXPECT_EQ(area1.view(0).size(), 2);
  TF_ASSERT_OK(area1.Reset(view2, 1, 1));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer2);
  EXPECT_EQ(area1.view(0).size(), 1);
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 1, 2));
  EXPECT_EQ(area2.num_views(), 1);
  EXPECT_EQ(area2.view_size(), 2);
  EXPECT_FALSE(area2.empty());
  ExpectSameAddress(area2.view(0).data(), pointer2);
  EXPECT_EQ(area2.view(0).size(), 2);
  TF_ASSERT_OK(area2.Reset(view2, 1, 1));
  EXPECT_EQ(area2.num_views(), 1);
  EXPECT_EQ(area2.view_size(), 1);
  EXPECT_FALSE(area2.empty());
  ExpectSameAddress(area2.view(0).data(), pointer2);
  EXPECT_EQ(area2.view(0).size(), 1);
}
// Tests that (Mutable)AlignedArea::Reset() can initialize to a sequence of
// multiple views, each starting one alignment window after the previous.
TEST(AlignedAreaTest, ResetMultiple) {
  const char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, 2 * internal::kAlignmentBytes));
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 11, 1));
  EXPECT_EQ(area1.num_views(), 11);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  // Each 1-byte view is padded out to a full alignment window.
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer1 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), 1);
  }
  TF_ASSERT_OK(area1.Reset(view1, 10, internal::kAlignmentBytes));
  EXPECT_EQ(area1.num_views(), 10);
  EXPECT_EQ(area1.view_size(), internal::kAlignmentBytes);
  EXPECT_FALSE(area1.empty());
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer1 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), internal::kAlignmentBytes);
  }
  TF_ASSERT_OK(area1.Reset(view2, 2, 2));
  EXPECT_EQ(area1.num_views(), 2);
  EXPECT_EQ(area1.view_size(), 2);
  EXPECT_FALSE(area1.empty());
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer2 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), 2);
  }
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, internal::kAlignmentBytes));
  EXPECT_EQ(area2.num_views(), 2);
  EXPECT_EQ(area2.view_size(), internal::kAlignmentBytes);
  EXPECT_FALSE(area2.empty());
  for (int i = 0; i < area2.num_views(); ++i) {
    ExpectSameAddress(area2.view(i).data(),
                      pointer2 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area2.view(i).size(), internal::kAlignmentBytes);
  }
}
// Tests that (Mutable)AlignedArea::Reset() fails when the view being split into
// sub-views is too small.  After each failure, the area is re-Reset to a valid
// configuration so the next failure starts from a known-good state.
TEST(AlignedAreaTest, ResetInvalid) {
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(nullptr, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(nullptr, 2 * internal::kAlignmentBytes));
  // View size larger than available view.
  AlignedArea area;
  EXPECT_THAT(area.Reset(view1, 1, 11 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view2, 1, 2 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
  // Total size larger than available view.
  EXPECT_THAT(area.Reset(view1, 12, 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 4, 2 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 3, 3 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 2, 5 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view2, 3, 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
  EXPECT_THAT(area.Reset(view2, 2, internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
}
// Tests that (Mutable)AlignedArea::Reset() can empty the area; num_views of 0
// makes the area empty regardless of view_size.
TEST(AlignedAreaTest, ResetEmpty) {
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(nullptr, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(nullptr, 2 * internal::kAlignmentBytes));
  // First point to a non-empty byte array, then clear.
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 11, 1));
  TF_ASSERT_OK(area1.Reset(view1, 0, 0));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 0);
  EXPECT_TRUE(area1.empty());
  TF_ASSERT_OK(area1.Reset(view2, 2, 1));
  TF_ASSERT_OK(area1.Reset(view2, 0, 100));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 100);
  EXPECT_TRUE(area1.empty());
  TF_ASSERT_OK(area1.Reset(view2, 2, 1));
  TF_ASSERT_OK(area1.Reset(MutableAlignedView(), 0, 1));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_TRUE(area1.empty());
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(view2, 0, 0));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 0);
  EXPECT_TRUE(area2.empty());
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(view2, 0, 100));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 100);
  EXPECT_TRUE(area2.empty());
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(MutableAlignedView(), 0, 1));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 1);
  EXPECT_TRUE(area2.empty());
}
// Tests that (Mutable)AlignedArea supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
// (The reverse, const to mutable, should not compile and is not tested here.)
TEST(AlignedAreaTest, CopyAndAssign) {
  char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  const char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  MutableAlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, ComputeAlignedAreaSize(1, 5)));
  AlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, ComputeAlignedAreaSize(2, 77)));
  MutableAlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 1, 5));
  AlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, 77));
  // Mutable-to-mutable copy, clear, and assign.
  MutableAlignedArea area3(area1);
  EXPECT_EQ(area3.num_views(), 1);
  EXPECT_EQ(area3.view_size(), 5);
  EXPECT_FALSE(area3.empty());
  ExpectSameAddress(area3.view(0).data(), pointer1);
  EXPECT_EQ(area3.view(0).size(), 5);
  area3 = MutableAlignedArea();
  EXPECT_EQ(area3.num_views(), 0);
  EXPECT_EQ(area3.view_size(), 0);
  EXPECT_TRUE(area3.empty());
  area3 = area1;
  EXPECT_EQ(area3.num_views(), 1);
  EXPECT_EQ(area3.view_size(), 5);
  EXPECT_FALSE(area3.empty());
  ExpectSameAddress(area3.view(0).data(), pointer1);
  EXPECT_EQ(area3.view(0).size(), 5);
  AlignedArea area4(area1);  // reinterprets type
  EXPECT_EQ(area4.num_views(), 1);
  EXPECT_EQ(area4.view_size(), 5);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer1);
  EXPECT_EQ(area4.view(0).size(), 5);
  area4 = AlignedArea();
  EXPECT_EQ(area4.num_views(), 0);
  EXPECT_EQ(area4.view_size(), 0);
  EXPECT_TRUE(area4.empty());
  area4 = area2;
  EXPECT_EQ(area4.num_views(), 2);
  EXPECT_EQ(area4.view_size(), 77);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer2);
  EXPECT_EQ(area4.view(0).size(), 77);
  // The second sub-view starts at the next aligned offset after the first.
  ExpectSameAddress(area4.view(1).data(), PadToAlignment(pointer2 + 77));
  EXPECT_EQ(area4.view(1).size(), 77);
  area4 = area1;  // reinterprets type
  EXPECT_EQ(area4.num_views(), 1);
  EXPECT_EQ(area4.view_size(), 5);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer1);
  EXPECT_EQ(area4.view(0).size(), 5);
  area4 = MutableAlignedArea();  // reinterprets type
  EXPECT_EQ(area4.num_views(), 0);
  EXPECT_EQ(area4.view_size(), 0);
  EXPECT_TRUE(area4.empty());
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// The current version of the serialized format, stamped into each spec by
// ArrayVariableStoreBuilder and checked in Reset().  Increment this if the
// serialized format changes in an incompatible way that can't be detected
// through other means.  For example,
// * If kAlignmentBytes is changed, then kVersion need not change because there
//   is a separate field for detecting alignment mismatch.
// * If ArrayVariableStoreSpec.variable is no longer populated, perhaps replaced
//   by some other approach, then kVersion should be incremented.
const uint32 ArrayVariableStore::kVersion = 0;
// Validates the |spec| against this binary's format constants (version,
// alignment, endian-ness), then slices one aligned sub-array per variable off
// of the main byte array |data|.  The variables must exactly tile |data|.
// On error, returns non-OK and modifies nothing.
tensorflow::Status ArrayVariableStore::Reset(const ArrayVariableStoreSpec &spec,
                                             AlignedView data) {
  if (!spec.has_version() || !spec.has_alignment_bytes() ||
      !spec.has_is_little_endian()) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec is missing a required field: ",
        spec.ShortDebugString());
  }
  if (spec.version() != kVersion) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.version (", spec.version(),
        ") does not match the binary (", kVersion, ")");
  }
  if (spec.alignment_bytes() != internal::kAlignmentBytes) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.alignment_bytes (", spec.alignment_bytes(),
        ") does not match the binary (", internal::kAlignmentBytes, ")");
  }

  // TODO(googleuser): It should be possible to correct an endian-ness mismatch.
  // A rough outline is:
  //   * VariableStore::Lookup() takes an additional argument set to sizeof(T).
  //   * Capture sizeof(T) and write it into the VariableSpec.
  //   * Detect endian mismatch and byte-swap variables with multi-byte types.
  if (spec.is_little_endian() != tensorflow::port::kLittleEndian) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.is_little_endian (", spec.is_little_endian(),
        ") does not match the binary (", tensorflow::port::kLittleEndian, ")");
  }

  for (const VariableSpec &variable_spec : spec.variable()) {
    // When the proto parser encounters an unknown enumerator on the wire, it
    // replaces it with the default value (i.e., FORMAT_UNKNOWN).  Therefore,
    // VariableSpec.format() will always return a valid enumerator.
    DCHECK(VariableSpec::Format_IsValid(variable_spec.format()));
    if (variable_spec.format() == VariableSpec::FORMAT_UNKNOWN) {
      return tensorflow::errors::InvalidArgument(
          "Unknown variable format: ", variable_spec.ShortDebugString());
    }
    if (variable_spec.format() == VariableSpec::FORMAT_FLAT &&
        variable_spec.num_views() != 1) {
      return tensorflow::errors::InvalidArgument(
          "Flat variables must have 1 view: ", variable_spec.ShortDebugString());
    }
  }

  // Build into a temp mapping to avoid modification on error.
  std::unique_ptr<std::map<Key, Value>> new_variables(
      new std::map<Key, Value>());

  // Slice sub-arrays off of the main byte array.
  const char *base = data.data();
  const char *const end = base + data.size();
  for (const VariableSpec &variable_spec : spec.variable()) {
    const size_t num_views = variable_spec.num_views();
    const size_t view_size = variable_spec.view_size();
    const size_t area_size = ComputeAlignedAreaSize(num_views, view_size);

    // NB: Compare sizes instead of pointers (i.e., base + area_size > end):
    // forming a pointer past the end of |data| would be undefined behavior if
    // |area_size| in the spec is corrupt or maliciously large.
    if (area_size > static_cast<size_t>(end - base)) {
      return tensorflow::errors::InvalidArgument(
          "Variable would overrun main byte array: ",
          variable_spec.ShortDebugString());
    }
    AlignedView view;
    TF_RETURN_IF_ERROR(view.Reset(base, area_size));
    base += area_size;  // remove claimed slice

    // Set dimensions from the spec.
    std::vector<size_t> dimensions(variable_spec.dimension().begin(),
                                   variable_spec.dimension().end());
    Value value(std::move(dimensions), AlignedArea());
    AlignedArea &area = value.second;
    TF_RETURN_IF_ERROR(area.Reset(view, num_views, view_size));

    // Currently, blocked variables are meant for fast inference algorithms,
    // which do not tolerate padding.  Raise errors if there is padding.
    if (variable_spec.format() ==
        VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX) {
      const size_t padding =
          variable_spec.view_size() % internal::kAlignmentBytes;
      if (padding != 0) {
        return tensorflow::errors::Internal(
            "Currently, fast matrix-vector operations do not support padded "
            "blocked matrices, but variable '",
            variable_spec.name(), "' has padding ", padding);
      }
    }

    const Key key(variable_spec.name(), variable_spec.format());
    // |value| is not used after this point, so it can be moved into the map.
    if (!new_variables->emplace(key, std::move(value)).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate variable: ", variable_spec.ShortDebugString());
    }
  }

  if (base != end) {
    return tensorflow::errors::InvalidArgument(
        "Variables do not completely cover main byte array: ", end - base,
        " bytes remaining");
  }

  // Success; make modifications.
  variables_ = std::move(new_variables);
  return tensorflow::Status::OK();
}
// Retrieves the dimensions and aligned content of the variable identified by
// (|name|, |format|).  Fails if the store is uninitialized or has no such
// variable.
tensorflow::Status ArrayVariableStore::Lookup(const string &name,
                                              VariableSpec::Format format,
                                              std::vector<size_t> *dimensions,
                                              AlignedArea *area) {
  if (variables_ == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "ArrayVariableStore not initialized");
  }
  const auto it = variables_->find(Key(name, format));
  if (it == variables_->end()) {
    return tensorflow::errors::NotFound(
        "ArrayVariableStore has no variable with name '", name, "' and format ",
        VariableSpec::Format_Name(format));
  }

  // Success; make modifications.
  *dimensions = it->second.first;
  *area = it->second.second;
  return tensorflow::Status::OK();
}
// Releases all variables.  Fails if called before Reset(), or called twice.
tensorflow::Status ArrayVariableStore::Close() {
  if (variables_ == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "ArrayVariableStore not initialized");
  }
  variables_ = nullptr;  // equivalent to variables_.reset()
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A variable store that groups all variables into a single byte array. This
// class and its subclasses are intended for use in production.
//
// Each variable occupies a sub-array of the main byte array. The mapping from
// the name and format of a variable to the sub-array containing its content is
// defined in ArrayVariableStoreSpec. The variables may appear in any order.
//
// This format allows variables to be mapped directly into memory, which reduces
// initialization time and supports usage on-device, where mmap() is effectively
// obligatory for large data resources.
class ArrayVariableStore : public VariableStore {
 public:
  // Creates an uninitialized store.  Reset() must be called before Lookup().
  ArrayVariableStore() = default;

  // Resets this to represent the variables defined by the |spec| and |data|.
  // The |data| must remain valid until this is destroyed or Reset().  (Note
  // that subclasses have simpler lifetime requirements).  On error, returns
  // non-OK and modifies nothing.
  tensorflow::Status Reset(const ArrayVariableStoreSpec &spec,
                           AlignedView data);

  // Implements VariableStore.
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override;

 private:
  friend class ArrayVariableStoreBuilder;  // for access to kVersion

  // The current version of the serialized format; see array_variable_store.cc.
  static const uint32 kVersion;

  // A (name,format) key associated with a variable.
  using Key = std::pair<string, VariableSpec::Format>;

  // Dimension vector and aligned area holding a variable's content.
  using Value = std::pair<const std::vector<size_t>, AlignedArea>;

  // Mapping from variable key to variable content.  Initially null, filled in
  // Reset(), and deleted in Close().  Wrapped in std::unique_ptr so the entire
  // mapping can be deleted.
  std::unique_ptr<std::map<Key, Value>> variables_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <tuple>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Copies the bytes of the |view| onto the end of the |data|, then appends
// zeros up to the next alignment boundary.
void Append(AlignedView view, string *data) {
  DCHECK_EQ(PadToAlignment(data->size()), data->size());
  const size_t padded_size = PadToAlignment(view.size());
  data->append(view.data(), view.size());
  data->append(padded_size - view.size(), '\0');
}
// As above, but appends every sub-view of the aligned |area| in order.
void Append(AlignedArea area, string *data) {
  DCHECK_EQ(PadToAlignment(data->size()), data->size());
  const size_t size_before = data->size();
  for (size_t view_index = 0; view_index < area.num_views(); ++view_index) {
    Append(area.view(view_index), data);
  }
  DCHECK_EQ(data->size() - size_before,
            ComputeAlignedAreaSize(area.num_views(), area.view_size()));
}
} // namespace
// Serializes the |variables| into the byte array |data| and writes the
// corresponding configuration into the |spec|.  Variables are laid out in
// iteration order.  On error, returns non-OK.
tensorflow::Status ArrayVariableStoreBuilder::Build(
    const Variables &variables, ArrayVariableStoreSpec *spec, string *data) {
  data->clear();
  spec->Clear();
  spec->set_version(ArrayVariableStore::kVersion);
  spec->set_alignment_bytes(internal::kAlignmentBytes);
  spec->set_is_little_endian(tensorflow::port::kLittleEndian);
  for (const auto &variable : variables) {
    // Bind references instead of copying via std::tie(); the dimension vector
    // and area can be large, and are only read here.
    const string &name = variable.first.first;
    const VariableSpec::Format format = variable.first.second;
    const std::vector<size_t> &dimensions = variable.second.first;
    const AlignedArea &area = variable.second.second;
    if (format == VariableSpec::FORMAT_FLAT && area.num_views() != 1) {
      return tensorflow::errors::InvalidArgument(
          "Flat variables must have 1 view, but '", name, "' has ",
          area.num_views());
    }
    VariableSpec *variable_spec = spec->add_variable();
    variable_spec->set_name(name);
    variable_spec->set_format(format);
    variable_spec->set_num_views(area.num_views());
    variable_spec->set_view_size(area.view_size());
    for (const size_t dimension : dimensions) {
      variable_spec->add_dimension(dimension);
    }
    Append(area, data);
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#include <map>
#include <string>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Utils for converting a set of variables into a byte array that can be loaded
// by ArrayVariableStore. See that class for details on the required format.
class ArrayVariableStoreBuilder {
 public:
  // The set of variables to serialize, as captured by
  // CaptureUsedVariableStoreWrapper.
  using Variables = CaptureUsedVariableStoreWrapper::Variables;

  // Forbids instantiation; pure static class.
  ArrayVariableStoreBuilder() = delete;
  ~ArrayVariableStoreBuilder() = delete;

  // Overwrites the |data| with a byte array that represents the |variables|,
  // and overwrites the |spec| with the associated configuration.  On error,
  // returns non-OK.
  static tensorflow::Status Build(const Variables &variables,
                                  ArrayVariableStoreSpec *spec, string *data);
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <map>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that the builder rejects flat variables whose view count is not 1.
TEST(ArrayVariableStoreBuilderTest, InvalidFlatVariable) {
  ArrayVariableStoreSpec spec;
  string data;
  AlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, 2 * internal::kAlignmentBytes));

  // Try an empty area.
  AlignedArea area;
  TF_ASSERT_OK(area.Reset(view, 0, 0));
  std::pair<string, VariableSpec::Format> key("foo", VariableSpec::FORMAT_FLAT);
  std::pair<std::vector<size_t>, AlignedArea> value({1}, area);
  ArrayVariableStoreBuilder::Variables variables;
  variables.emplace_back(key, value);
  EXPECT_THAT(ArrayVariableStoreBuilder::Build(variables, &spec, &data),
              test::IsErrorWithSubstr(
                  "Flat variables must have 1 view, but 'foo' has 0"));

  // Try an area with more than 1 sub-view.
  TF_ASSERT_OK(area.Reset(view, 2, 0));
  variables[0].second.second = area;
  EXPECT_THAT(ArrayVariableStoreBuilder::Build(variables, &spec, &data),
              test::IsErrorWithSubstr(
                  "Flat variables must have 1 view, but 'foo' has 2"));
}
// Tests that the builder succeeds on good inputs and reproduces an expected
// byte array.
//
// NB: Since this test directly compares the byte array, it implicitly requires
// that the builder lays out the variables in a particular order. If that order
// changes, the test expectations must be updated.
TEST(ArrayVariableStoreBuilderTest, RegressionTest) {
  // Paths for rewriting the golden files in the source tree (used only when
  // the `if (false)` branch below is flipped on).
  const string kLocalSpecPath =
      "dragnn/runtime/testdata/array_variable_store_spec";
  const string kLocalDataPath =
      "dragnn/runtime/testdata/array_variable_store_data";
  // Paths to the checked-in golden files used for comparison.
  const string kExpectedSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_spec");
  const string kExpectedDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_data");

  // If these values are changed, make sure to rewrite the test data and update
  // array_variable_store_test.cc.
  UniqueMatrix<float> foo({{0.0, 0.5, 1.0},  //
                           {1.5, 2.0, 2.5},  //
                           {3.0, 3.5, 4.0},  //
                           {4.5, 5.0, 5.5}});
  UniqueMatrix<double> baz_data({{1.0, 2.0, 2.0, 2.0},  //
                                 {3.0, 4.0, 4.0, 4.0},  //
                                 {5.0, 6.0, 6.0, 6.0},  //
                                 {7.0, 8.0, 8.0, 8.0}});

  // One row-major variable ("foo") and one column-blocked variable ("baz").
  ArrayVariableStoreBuilder::Variables variables;
  std::pair<string, VariableSpec::Format> foo_key(
      "foo", VariableSpec::FORMAT_ROW_MAJOR_MATRIX);
  std::pair<std::vector<size_t>, AlignedArea> foo_value(
      {foo->num_rows(), foo->num_columns()}, AlignedArea(foo.area()));
  variables.push_back(std::make_pair(foo_key, foo_value));
  std::pair<string, VariableSpec::Format> baz_key(
      "baz", VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
  std::pair<std::vector<size_t>, AlignedArea> baz_value(
      {2, 8, 4}, AlignedArea(baz_data.area()));
  variables.push_back(std::make_pair(baz_key, baz_value));

  // Pre-populate the outputs with garbage to verify that Build() overwrites.
  ArrayVariableStoreSpec actual_spec;
  actual_spec.set_version(999);
  string actual_data = "garbage to be overwritten";
  TF_ASSERT_OK(
      ArrayVariableStoreBuilder::Build(variables, &actual_spec, &actual_data));

  if (false) {
    // Rewrite the test data.
    TF_CHECK_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                           kLocalSpecPath, actual_spec));
    TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                              kLocalDataPath, actual_data));
  } else {
    // Compare to the test data.
    ArrayVariableStoreSpec expected_spec;
    string expected_data;
    TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                          kExpectedSpecPath, &expected_spec));
    TF_CHECK_OK(tensorflow::ReadFileToString(
        tensorflow::Env::Default(), kExpectedDataPath, &expected_data));
    EXPECT_THAT(actual_spec, test::EqualsProto(expected_spec));
    EXPECT_EQ(actual_data, expected_data);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include <string.h>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/file_array_variable_store.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/mmap_array_variable_store.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the blocked |matrix| contains the |data|, where data[i][j] is
// the j'th element of the i'th vector of the matrix.
template <class T>
void ExpectBlockedData(BlockedMatrix<T> matrix,
                       const std::vector<std::vector<T>> &data) {
  // Use ASSERT for the size checks: the loops below index into both |matrix|
  // and |data|, so continuing past a size mismatch (as EXPECT would) could
  // read out of bounds.
  ASSERT_EQ(matrix.num_vectors(), data.size());
  // The indices don't really have semantic names, so we just use `i` and `j`.
  // See BlockedMatrixFormat for details.
  for (size_t i = 0; i < data.size(); ++i) {
    ASSERT_EQ(matrix.block_size(), data[i].size());
    for (size_t j = 0; j < data[i].size(); ++j) {
      EXPECT_EQ(matrix.vector(i)[j], data[i][j]);
    }
  }
}
// Parses and returns the ArrayVariableStoreSpec encoded in the |text|.
// Check-fails on malformed input.
ArrayVariableStoreSpec MakeSpec(const string &text) {
  ArrayVariableStoreSpec parsed_spec;
  CHECK(TextFormat::ParseFromString(text, &parsed_spec));
  return parsed_spec;
}
// Returns an ArrayVariableStoreSpec whose top-level settings match this binary
// and whose variables are parsed from the |variables_text|.
ArrayVariableStoreSpec MakeSpecWithVariables(const string &variables_text) {
  const string full_text = tensorflow::strings::StrCat(
      "version: 0 alignment_bytes: ", internal::kAlignmentBytes,
      " is_little_endian: ", tensorflow::port::kLittleEndian, " ",
      variables_text);
  return MakeSpec(full_text);
}
// Tests that kLittleEndian actually means little-endian by inspecting the raw
// bytes of a known 32-bit value.
TEST(ArrayVariableStoreTest, EndianDetection) {
  static_assert(sizeof(uint32) == 4 * sizeof(uint8), "Unexpected int sizes");
  const uint32 word = 0xdeadbeef;
  uint8 bytes[4];
  memcpy(bytes, &word, 4 * sizeof(uint8));

  // Expected byte order on a big-endian machine; reversed on little-endian.
  const uint8 kBigEndianBytes[4] = {0xde, 0xad, 0xbe, 0xef};
  for (int i = 0; i < 4; ++i) {
    const int index = tensorflow::port::kLittleEndian ? 3 - i : i;
    EXPECT_EQ(bytes[index], kBigEndianBytes[i]);
  }
}
// Tests that the store rejects specs missing any of the three required fields
// (version, alignment_bytes, is_little_endian).
TEST(ArrayVariableStoreTest, MissingRequiredField) {
  for (const string text :
       {"version: 0 alignment_bytes: 0", "version: 0 is_little_endian: true",
        "alignment_bytes: 0 is_little_endian: true"}) {
    ArrayVariableStore incomplete_store;
    EXPECT_THAT(incomplete_store.Reset(MakeSpec(text), AlignedView()),
                test::IsErrorWithSubstr(
                    "ArrayVariableStoreSpec is missing a required field"));
  }
}
// Tests that the store rejects a spec whose version differs from the binary's.
TEST(ArrayVariableStoreTest, VersionMismatch) {
  ArrayVariableStore store;
  const tensorflow::Status status = store.Reset(
      MakeSpec("version: 999 alignment_bytes: 0 is_little_endian: true"),
      AlignedView());
  EXPECT_THAT(status,
              test::IsErrorWithSubstr("ArrayVariableStoreSpec.version (999) "
                                      "does not match the binary (0)"));
}
// Tests that the store rejects a spec whose alignment differs from the
// binary's.
TEST(ArrayVariableStoreTest, AlignmentMismatch) {
  const string kExpectedError = tensorflow::strings::StrCat(
      "ArrayVariableStoreSpec.alignment_bytes (1) does not match "
      "the binary (", internal::kAlignmentBytes, ")");
  ArrayVariableStore store;
  EXPECT_THAT(
      store.Reset(
          MakeSpec("version: 0 alignment_bytes: 1 is_little_endian: true"),
          AlignedView()),
      test::IsErrorWithSubstr(kExpectedError));
}
// Tests that the store rejects a spec whose endian-ness differs from the
// binary's.
TEST(ArrayVariableStoreTest, EndiannessMismatch) {
  const bool kWrongEndian = !tensorflow::port::kLittleEndian;
  const string kSpecText = tensorflow::strings::StrCat(
      "version: 0 alignment_bytes: ", internal::kAlignmentBytes,
      " is_little_endian: ", kWrongEndian);
  const string kExpectedError = tensorflow::strings::StrCat(
      "ArrayVariableStoreSpec.is_little_endian (", kWrongEndian,
      ") does not match the binary (", tensorflow::port::kLittleEndian, ")");
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(MakeSpec(kSpecText), AlignedView()),
              test::IsErrorWithSubstr(kExpectedError));
}
// Tests that the store rejects FORMAT_UNKNOWN variables.
TEST(ArrayVariableStoreTest, RejectFormatUnknown) {
  ArrayVariableStore store;
  EXPECT_THAT(
      store.Reset(MakeSpecWithVariables("variable { format: FORMAT_UNKNOWN }"),
                  AlignedView()),
      test::IsErrorWithSubstr("Unknown variable format"));
}
// Tests that the store rejects FORMAT_FLAT variables with zero sub-views.
TEST(ArrayVariableStoreTest, TooFewViewsForFlatVariable) {
  ArrayVariableStore store;
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_FLAT num_views: 0 }");
  EXPECT_THAT(store.Reset(spec, AlignedView()),
              test::IsErrorWithSubstr("Flat variables must have 1 view"));
}
// Tests that the store rejects FORMAT_FLAT variables with multiple sub-views.
TEST(ArrayVariableStoreTest, TooManyViewsForFlatVariable) {
  ArrayVariableStore store;
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_FLAT num_views: 2 }");
  EXPECT_THAT(store.Reset(spec, AlignedView()),
              test::IsErrorWithSubstr("Flat variables must have 1 view"));
}
// Tests that the store accepts FORMAT_ROW_MAJOR_MATRIX variables with one
// sub-view (unlike FORMAT_FLAT, the view count is not restricted).
TEST(ArrayVariableStoreTest, MatrixWithOneRow) {
  ArrayVariableStore store;
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_ROW_MAJOR_MATRIX num_views: 1 view_size: 0 }");
  TF_EXPECT_OK(store.Reset(spec, AlignedView()));
}
// Tests that the store rejects variables that overrun the main byte array.
TEST(ArrayVariableStoreTest, VariableOverrunsMainByteArray) {
  AlignedView data;
  TF_ASSERT_OK(data.Reset(nullptr, 1023));  // one byte short of the 1024 needed
  ArrayVariableStore store;
  EXPECT_THAT(
      store.Reset(
          MakeSpecWithVariables(
              "variable { format: FORMAT_FLAT num_views: 1 view_size: 1024 }"),
          data),
      test::IsErrorWithSubstr("Variable would overrun main byte array"));
}
// Tests that the store rejects duplicate variables.  Note that the first and
// third variables share the key ('x', FORMAT_FLAT).
TEST(ArrayVariableStoreTest, DuplicateVariables) {
  const string kVariables = R"(
    variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
    variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
    variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
  )";
  AlignedView big_array;
  TF_ASSERT_OK(big_array.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(MakeSpecWithVariables(kVariables), big_array),
              test::IsErrorWithSubstr("Duplicate variable"));
}
// Tests that the store rejects sets of variables that do not completely cover
// the main byte array (here, three small variables against a 1MB array).
TEST(ArrayVariableStoreTest, LeftoverBytesInMainByteArray) {
  const string kVariables = R"(
    variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
    variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
    variable { name: 'z' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
  )";
  AlignedView big_array;
  TF_ASSERT_OK(big_array.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(MakeSpecWithVariables(kVariables), big_array),
              test::IsErrorWithSubstr(
                  "Variables do not completely cover main byte array"));
}
// The fast matrix-vector routines do not support padding, so the store must
// reject blocked variables whose view_size leaves alignment padding.
TEST(ArrayVariableStoreTest, PaddingInBlockedMatrix) {
  const string kVariables = R"(
    variable {
      name: "baz"
      format: FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
      num_views: 4
      view_size: 16
      dimension: 2
      dimension: 4
      dimension: 2
    }
  )";
  AlignedView big_array;
  TF_ASSERT_OK(big_array.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(MakeSpecWithVariables(kVariables), big_array),
              test::IsErrorWithSubstr(
                  "Currently, fast matrix-vector operations do not support "
                  "padded blocked matrices"));
}
// Tests that lookups fail on a store that was never Reset().
TEST(ArrayVariableStoreTest, LookupWhenUninitialized) {
  Vector<float> vector;
  ArrayVariableStore uninitialized_store;
  EXPECT_THAT(uninitialized_store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Tests that the store can use an empty byte array when there are no variables.
TEST(ArrayVariableStoreTest, EmptyByteArrayWorksIfNoVariables) {
  ArrayVariableStore empty_store;
  TF_EXPECT_OK(empty_store.Reset(MakeSpecWithVariables(""), AlignedView()));

  // The store loaded successfully but contains nothing, so lookups fail.
  Vector<float> vector;
  EXPECT_THAT(
      empty_store.Lookup("foo", &vector),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'foo' and format FORMAT_FLAT"));
}
// Tests that closing a never-initialized store is an error.
TEST(ArrayVariableStoreTest, CloseBeforeReset) {
  ArrayVariableStore unused_store;
  EXPECT_THAT(unused_store.Close(),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Tests that the store can be closed exactly once after it has been
// initialized.
TEST(ArrayVariableStoreTest, CloseAfterReset) {
  ArrayVariableStore store;
  TF_ASSERT_OK(store.Reset(MakeSpecWithVariables(""), AlignedView()));
  TF_EXPECT_OK(store.Close());

  // A second Close() fails, just like closing before Reset().
  EXPECT_THAT(store.Close(),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Typed test fixture, templated on an ArrayVariableStore subclass that loads
// its main byte array from a file path.
template <class Subclass>
class ArrayVariableStoreSubclassTest : public ::testing::Test {};

// The subclasses under test; see file_array_variable_store.h and
// mmap_array_variable_store.h.
typedef ::testing::Types<FileArrayVariableStore, MmapArrayVariableStore>
    Subclasses;
TYPED_TEST_CASE(ArrayVariableStoreSubclassTest, Subclasses);
// Tests that the store fails to load a non-existent file.  The exact error
// message comes from the file system, so only an error status is expected.
TYPED_TEST(ArrayVariableStoreSubclassTest, NonExistentFile) {
  const string kMissingPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/non_existent_file");
  TypeParam store;
  EXPECT_THAT(store.Reset(MakeSpecWithVariables(""), kMissingPath),
              test::IsErrorWithSubstr(""));
}
// Tests that the store can load an empty file if there are no variables.
TYPED_TEST(ArrayVariableStoreSubclassTest, EmptyFile) {
  const string kEmptyFilePath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/empty_file");
  TypeParam store;
  TF_ASSERT_OK(store.Reset(MakeSpecWithVariables(""), kEmptyFilePath));

  // The store loaded successfully but contains no variables.
  Vector<float> vector;
  EXPECT_THAT(store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'foo' and format FORMAT_FLAT"));
  Matrix<float> row_major_matrix;
  EXPECT_THAT(
      store.Lookup("bar", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'bar' and format FORMAT_ROW_MAJOR_MATRIX"));
}
// Tests that the store, when loading a pre-built byte array, produces the same
// variables that the builder converted.
TYPED_TEST(ArrayVariableStoreSubclassTest, RegressionTest) {
  // Paths to the spec and data produced by array_variable_store_builder_test.
  const string kSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_spec");
  const string kDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_data");
  ArrayVariableStoreSpec spec;
  TF_CHECK_OK(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), kSpecPath, &spec));
  TypeParam store;
  TF_ASSERT_OK(store.Reset(spec, kDataPath));

  // Row-major format.
  Matrix<float> foo;
  TF_ASSERT_OK(store.Lookup("foo", &foo));
  // NB: These assertions must be kept in sync with the variables defined in
  // array_variable_store_builder_test.cc.
  ExpectMatrix(foo, {{0.0, 0.5, 1.0},  //
                     {1.5, 2.0, 2.5},  //
                     {3.0, 3.5, 4.0},  //
                     {4.5, 5.0, 5.5}});

  // Blocked formats.
  BlockedMatrix<double> baz;
  TF_ASSERT_OK(store.Lookup("baz", &baz));
  EXPECT_EQ(baz.num_rows(), 2);
  EXPECT_EQ(baz.num_columns(), 8);
  EXPECT_EQ(baz.block_size(), 4);
  ExpectBlockedData(baz, {{1.0, 2.0, 2.0, 2.0},  //
                          {3.0, 4.0, 4.0, 4.0},  //
                          {5.0, 6.0, 6.0, 6.0},  //
                          {7.0, 8.0, 8.0, 8.0}});

  // Try versions of "foo" and "baz" with the wrong format; lookups are keyed
  // on (name, format), so these must fail even though the names exist.
  Vector<float> vector;
  Matrix<float> row_major_matrix;
  EXPECT_THAT(store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'foo' and format FORMAT_FLAT"));
  EXPECT_THAT(store.Lookup("baz", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'baz' and format FORMAT_FLAT"));
  EXPECT_THAT(
      store.Lookup("baz", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'baz' and format FORMAT_ROW_MAJOR_MATRIX"));

  // Try totally unknown variables.
  EXPECT_THAT(store.Lookup("missing", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'missing' and format FORMAT_FLAT"));
  EXPECT_THAT(
      store.Lookup("missing", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'missing' and format FORMAT_ROW_MAJOR_MATRIX"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <set>
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Parses the |mapping| into the registered attributes.  Two passes: the first
// parses every provided (name, value) pair, and the second verifies that no
// mandatory attribute was left unspecified.  Returns non-OK on unknown names,
// parse failures, or missing mandatory attributes.
tensorflow::Status Attributes::Reset(
    const tensorflow::protobuf::Map<string, string> &mapping) {
  // Pass 1: every entry in the |mapping| must name a registered attribute and
  // its value must parse cleanly into that attribute.
  for (const auto &entry : mapping) {
    const auto lookup = attributes_.find(entry.first);
    if (lookup == attributes_.end()) {
      return tensorflow::errors::InvalidArgument("Unknown attribute: ",
                                                 entry.first);
    }
    TF_RETURN_IF_ERROR(lookup->second->Parse(entry.second));
  }
  // Pass 2: collect mandatory attributes that the |mapping| never mentioned.
  // A std::set keeps the error message deterministically ordered.
  std::set<string> missing_mandatory_attributes;
  for (const auto &entry : attributes_) {
    if (entry.second->IsMandatory() &&
        mapping.find(entry.first) == mapping.end()) {
      missing_mandatory_attributes.insert(entry.first);
    }
  }
  if (!missing_mandatory_attributes.empty()) {
    return tensorflow::errors::InvalidArgument(
        "Missing mandatory attributes: ",
        tensorflow::str_util::Join(missing_mandatory_attributes, " "));
  }
  return tensorflow::Status::OK();
}
// Registers the |attribute| under |name|.  Each name may be registered at most
// once; a duplicate is a programming error, checked in debug builds only.
void Attributes::Register(const string &name, Attribute *attribute) {
  const auto result = attributes_.emplace(name, attribute);
  DCHECK(result.second) << "Duplicate attribute '" << name << "'";
}
// Parses a string attribute, which is a trivial copy that never fails.
tensorflow::Status Attributes::ParseValue(const string &str, string *value) {
  value->assign(str);
  return tensorflow::Status::OK();
}
// Parses a bool attribute.  Accepts exactly "true" and "false" in any letter
// case; everything else (including surrounding whitespace) is an error.
tensorflow::Status Attributes::ParseValue(const string &str, bool *value) {
  const string lowered = tensorflow::str_util::Lowercase(str);
  if (lowered == "true") {
    *value = true;
    return tensorflow::Status::OK();
  }
  if (lowered == "false") {
    *value = false;
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as bool: ", str);
}
// Parses an int32 attribute via the strict TF string-to-number conversion.
tensorflow::Status Attributes::ParseValue(const string &str, int32 *value) {
  if (tensorflow::strings::safe_strto32(str, value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as int32: ", str);
}
// Parses an int64 attribute via the strict TF string-to-number conversion.
tensorflow::Status Attributes::ParseValue(const string &str, int64 *value) {
  if (tensorflow::strings::safe_strto64(str, value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as int64: ", str);
}
// Parses a size_t attribute.  Goes through int64 so that negative inputs are
// rejected instead of wrapping around to huge unsigned values.
tensorflow::Status Attributes::ParseValue(const string &str, size_t *value) {
  int64 parsed = 0;
  const bool ok =
      tensorflow::strings::safe_strto64(str, &parsed) && parsed >= 0;
  if (!ok) {
    return tensorflow::errors::InvalidArgument(
        "Attribute can't be parsed as size_t: ", str);
  }
  *value = static_cast<size_t>(parsed);
  return tensorflow::Status::OK();
}
// Parses a float attribute via the TF string-to-float conversion.
tensorflow::Status Attributes::ParseValue(const string &str, float *value) {
  if (tensorflow::strings::safe_strtof(str.c_str(), value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as float: ", str);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for parsing configuration attributes from (name,value) string pairs as
// typed values. Intended for parsing RegisteredModuleSpec.parameters, similar
// to get_attrs_with_defaults() in network_units.py. Example usage:
//
// // Create a subclass of Attributes.
// struct MyComponentAttributes : public Attributes {
// // Mandatory attribute with type and name. The "this" allows the attribute
// // to register itself in its container---i.e., MyComponentAttributes.
// Mandatory<float> coefficient{"coefficient", this};
//
// // Optional attributes with type, name, and default value.
// Optional<bool> ignore_case{"ignore_case", true, this};
// Optional<std::vector<int32>> layer_sizes{"layer_sizes", {1, 2, 3}, this};
//
// // Ignored attribute, which does not parse any value.
// Ignored dropout_keep_prob{"dropout_keep_prob", this};
// };
//
// // Initialize an instance of the subclass from a string-to-string mapping.
// RegisteredModuleSpec spec;
// MyComponentAttributes attributes;
// TF_RETURN_IF_ERROR(attributes.Reset(spec.parameters()));
//
// // Access the attributes as accessors.
// bool ignore_case = attributes.ignore_case();
// float coefficient = attributes.coefficient();
// const std::vector<int32> &layer_sizes = attributes.layer_sizes();
//
// See the unit test for additional usage examples.
//
// TODO(googleuser): Build typed attributes into the RegisteredModuleSpec and
// get rid of this module.
#ifndef DRAGNN_RUNTIME_ATTRIBUTES_H_
#define DRAGNN_RUNTIME_ATTRIBUTES_H_
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for sets of attributes. Use as indicated in the file comment.
class Attributes {
 public:
  // Untyped mapping from which typed attributes are parsed.
  using Mapping = tensorflow::protobuf::Map<string, string>;
  // Forbids copying, which would invalidate the pointers in |attributes_|.
  Attributes(const Attributes &that) = delete;
  Attributes &operator=(const Attributes &that) = delete;
  // Parses registered attributes from the name-to-value |mapping|.  On error,
  // returns non-OK.  Errors include unknown names in |mapping|, string-to-value
  // parsing failures, and missing mandatory attributes.
  tensorflow::Status Reset(const Mapping &mapping);
 protected:
  // Implementations of the supported kinds of attributes, defined below.
  class Ignored;
  template <class T>
  class Optional;
  template <class T>
  class Mandatory;
  // Forbids lifecycle management except via subclasses.
  Attributes() = default;
  virtual ~Attributes() = default;
 private:
  // Base class for an individual attribute, defined below.
  class Attribute;
  // Registers the |attribute| with the |name|, which must be unique.  Called
  // by the attribute constructors above; a duplicate name is DCHECK-fatal.
  void Register(const string &name, Attribute *attribute);
  // Parses the string |str| into the |value| object.  One overload per
  // supported attribute type; parse failures return InvalidArgument.
  static tensorflow::Status ParseValue(const string &str, string *value);
  static tensorflow::Status ParseValue(const string &str, bool *value);
  static tensorflow::Status ParseValue(const string &str, int32 *value);
  static tensorflow::Status ParseValue(const string &str, int64 *value);
  static tensorflow::Status ParseValue(const string &str, size_t *value);
  static tensorflow::Status ParseValue(const string &str, float *value);
  template <class Element>
  static tensorflow::Status ParseValue(const string &str,
                                       std::vector<Element> *value);
  // Registered attributes, keyed by name.  Non-owning: each Attribute lives
  // as a member of the Attributes subclass itself.
  std::map<string, Attribute *> attributes_;
};
// Implementation details below.
// Base class for individual attributes.
class Attributes::Attribute {
 public:
  Attribute() = default;
  // Not copyable: the owning Attributes holds a raw pointer to each instance.
  Attribute(const Attribute &that) = delete;
  Attribute &operator=(const Attribute &that) = delete;
  virtual ~Attribute() = default;
  // Parses the |value| string into a typed object.  On error, returns non-OK.
  virtual tensorflow::Status Parse(const string &value) = 0;
  // Returns true if this is a mandatory attribute.  Defaults to optional.
  virtual bool IsMandatory() const { return false; }
};
// Implements an ignored attribute: its name is recognized during Reset(), but
// any provided value is discarded without parsing.
class Attributes::Ignored : public Attribute {
 public:
  // Registers this attribute under |name| in the |attributes| container.
  Ignored(const string &name, Attributes *attributes) {
    attributes->Register(name, this);
  }
  // Discards the value and always succeeds.
  tensorflow::Status Parse(const string & /*value*/) override {
    return tensorflow::Status::OK();
  }
};
// Implements an optional attribute.
template <class T>
class Attributes::Optional : public Attribute {
 public:
  // Registers this in the |attributes| with the |name| and |default_value|.
  Optional(const string &name, const T &default_value, Attributes *attributes)
      : value_(default_value) {
    attributes->Register(name, this);
  }
  // Parses the |value| into the |value_|, overwriting the default.
  tensorflow::Status Parse(const string &value) override {
    return ParseValue(value, &value_);
  }
  // Returns the parsed |value_|.  Overloading operator() allows a struct member
  // to be called like an accessor.
  const T &operator()() const { return value_; }
 private:
  // The parsed value, or the default value if not explicitly specified.
  T value_;
};
// Implements a mandatory attribute.  Reuses Optional<T> for parsing and
// storage, but reports itself as mandatory so Reset() fails unless the
// attribute is explicitly specified in the mapping.
template <class T>
class Attributes::Mandatory : public Optional<T> {
 public:
  // Registers this in the |attributes| with the |name|.  The default value is
  // never observable because Reset() rejects unspecified mandatory attributes.
  Mandatory(const string &name, Attributes *attributes)
      : Optional<T>(name, T(), attributes) {}
  // Returns true since this is mandatory.
  bool IsMandatory() const override { return true; }
  // NOTE: No value storage here.  A previous revision declared a private
  // `T value_;` that shadowed Optional<T>::value_ and was never read or
  // written (operator() returns the base's value); it has been removed.
};
// Parses a comma-delimited |str| into a vector of typed elements.  An empty
// string denotes an empty vector; otherwise every comma-separated token must
// parse as an Element via the scalar ParseValue() overloads.
template <class Element>
tensorflow::Status Attributes::ParseValue(const string &str,
                                          std::vector<Element> *value) {
  value->clear();
  if (str.empty()) return tensorflow::Status::OK();
  for (const string &token : tensorflow::str_util::Split(str, ",")) {
    // Append first and parse in place, so the element is built exactly once.
    value->emplace_back();
    TF_RETURN_IF_ERROR(ParseValue(token, &value->back()));
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ATTRIBUTES_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <map>
#include <set>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Converts the |std_map| into the protobuf map type consumed by Attributes.
Attributes::Mapping MakeMapping(const std::map<string, string> &std_map) {
  Attributes::Mapping mapping;
  for (const auto &entry : std_map) mapping[entry.first] = entry.second;
  return mapping;
}
// Returns a mapping that explicitly sets every attribute declared by the
// OptionalAttributes and MandatoryAttributes fixtures below.
Attributes::Mapping GetFullySpecifiedMapping() {
  const std::map<string, string> values = {{"some_string", "explicit"},
                                           {"some_bool", "true"},
                                           {"some_int32", "987"},
                                           {"some_int64", "654321"},
                                           {"some_size_t", "7777777"},
                                           {"some_float", "0.25"},
                                           {"some_intvec", "2,3,5,7,11,13"},
                                           {"some_strvec", "a,bc,def"}};
  return MakeMapping(values);
}
// A set of optional attributes, one per supported scalar and vector type.
// Each default value here is asserted by the Defaulted test below.
struct OptionalAttributes : public Attributes {
  Optional<string> some_string{"some_string", "default", this};
  Optional<bool> some_bool{"some_bool", false, this};
  Optional<int32> some_int32{"some_int32", 32, this};
  Optional<int64> some_int64{"some_int64", 64, this};
  Optional<size_t> some_size_t{"some_size_t", 999, this};
  Optional<float> some_float{"some_float", -1.5, this};
  Optional<std::vector<int32>> some_intvec{"some_intvec", {}, this};
  Optional<std::vector<string>> some_strvec{"some_strvec", {"x", "y"}, this};
};
// Tests that attributes take their default values when they are not explicitly
// specified.
TEST(OptionalAttributesTest, Defaulted) {
  // An empty mapping leaves every attribute at its declared default.
  Attributes::Mapping mapping;
  OptionalAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
  EXPECT_EQ(attributes.some_string(), "default");
  EXPECT_FALSE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 32);
  EXPECT_EQ(attributes.some_int64(), 64);
  EXPECT_EQ(attributes.some_size_t(), 999);
  EXPECT_EQ(attributes.some_float(), -1.5);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>());
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"x", "y"}));
}
// Tests that attributes can be overridden to explicitly-specified values.
TEST(OptionalAttributesTest, FullySpecified) {
  OptionalAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(GetFullySpecifiedMapping()));
  // Every expectation mirrors an entry in GetFullySpecifiedMapping().
  EXPECT_EQ(attributes.some_string(), "explicit");
  EXPECT_TRUE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 987);
  EXPECT_EQ(attributes.some_int64(), 654321);
  EXPECT_EQ(attributes.some_size_t(), 7777777);
  EXPECT_EQ(attributes.some_float(), 0.25);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>({2, 3, 5, 7, 11, 13}));
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"a", "bc", "def"}));
}
// Tests that attribute parsing fails for an unknown name.
TEST(OptionalAttributesTest, UnknownName) {
  // The value is irrelevant; the unregistered name alone triggers the error.
  const Attributes::Mapping mapping = MakeMapping({{"unknown", "##BAD##"}});
  OptionalAttributes attributes;
  EXPECT_THAT(attributes.Reset(mapping),
              test::IsErrorWithSubstr("Unknown attribute"));
}
// Tests that attribute parsing fails for malformed bool values.
TEST(OptionalAttributesTest, BadBool) {
  // Only exact "true"/"false" (case-insensitive) are accepted: surrounding or
  // embedded whitespace, numerals, and single-letter forms must all fail.
  for (const string &value :
       {" true", "true ", "tr ue", "arst", "1", "t", "y", "yes", " false",
        "false ", "fa lse", "oien", "0", "f", "n", "no"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_bool", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as bool"));
  }
}
// Tests that attribute parsing works for well-formed bool values.
TEST(OptionalAttributesTest, GoodBool) {
  // "true" parses in any letter case.
  for (const string &value : {"true", "TRUE", "True", "tRuE"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_bool", value}});
    OptionalAttributes attributes;
    TF_ASSERT_OK(attributes.Reset(mapping));
    EXPECT_TRUE(attributes.some_bool());
  }
  // Likewise for "false".
  for (const string &value : {"false", "FALSE", "False", "fAlSe"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_bool", value}});
    OptionalAttributes attributes;
    TF_ASSERT_OK(attributes.Reset(mapping));
    EXPECT_FALSE(attributes.some_bool());
  }
}
// Tests that attribute parsing fails for malformed int32 values, including
// floating-point forms that an integer parse must reject.
TEST(OptionalAttributesTest, BadInt32) {
  for (const string &value : {"hello", "true", "1.0", "inf", "nan"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_int32", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as int32"));
  }
}
// Tests that attribute parsing fails for malformed int64 values, including
// floating-point forms that an integer parse must reject.
TEST(OptionalAttributesTest, BadInt64) {
  for (const string &value : {"hello", "true", "1.0", "inf", "nan"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_int64", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as int64"));
  }
}
// Tests that attribute parsing fails for malformed size_t values.  In addition
// to non-numeric forms, negative numbers must be rejected rather than wrapped.
TEST(OptionalAttributesTest, BadSizeT) {
  for (const string &value :
       {"hello", "true", "1.0", "inf", "nan", "-1.0", "-123"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_size_t", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as size_t"));
  }
}
// Tests that attribute parsing fails for malformed floats.
TEST(OptionalAttributesTest, BadFloat) {
  for (const string &value : {"hello", "true"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_float", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as float"));
  }
}
// Tests that attribute parsing fails for malformed std::vector<int32> values.
// The error message names the element type (int32), since parsing fails on a
// per-element basis.
TEST(OptionalAttributesTest, BadIntVector) {
  for (const string &value :
       {"hello", "true", "1.0", "inf", "nan", "true,false", "foo,bar,baz"}) {
    const Attributes::Mapping mapping = MakeMapping({{"some_intvec", value}});
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Attribute can't be parsed as int32"));
  }
}
// A set of mandatory attributes mirroring OptionalAttributes, so the same
// fully-specified mapping can drive both fixtures.
struct MandatoryAttributes : public Attributes {
  Mandatory<string> some_string{"some_string", this};
  Mandatory<bool> some_bool{"some_bool", this};
  Mandatory<int32> some_int32{"some_int32", this};
  Mandatory<int64> some_int64{"some_int64", this};
  Mandatory<size_t> some_size_t{"some_size_t", this};
  Mandatory<float> some_float{"some_float", this};
  Mandatory<std::vector<int32>> some_intvec{"some_intvec", this};
  Mandatory<std::vector<string>> some_strvec{"some_strvec", this};
};
// Tests that attribute parsing works when all mandatory attributes are
// explicitly specified.
TEST(MandatoryAttributesTest, FullySpecified) {
  MandatoryAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(GetFullySpecifiedMapping()));
  // Every expectation mirrors an entry in GetFullySpecifiedMapping().
  EXPECT_EQ(attributes.some_string(), "explicit");
  EXPECT_TRUE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 987);
  EXPECT_EQ(attributes.some_int64(), 654321);
  EXPECT_EQ(attributes.some_size_t(), 7777777);
  EXPECT_EQ(attributes.some_float(), 0.25);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>({2, 3, 5, 7, 11, 13}));
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"a", "bc", "def"}));
}
// Tests that attribute parsing fails when even one mandatory attribute is not
// explicitly specified.
TEST(MandatoryAttributesTest, MissingAttribute) {
  // Remove each attribute from the full mapping in turn; any single omission
  // must cause Reset() to fail.
  for (const auto &it : GetFullySpecifiedMapping()) {
    const string &name = it.first;
    Attributes::Mapping mapping = GetFullySpecifiedMapping();
    CHECK_EQ(mapping.erase(name), 1);
    MandatoryAttributes attributes;
    EXPECT_THAT(attributes.Reset(mapping),
                test::IsErrorWithSubstr("Missing mandatory attributes"));
  }
}
// A set of ignored attributes: recognized by name, values discarded.
struct IgnoredAttributes : public Attributes {
  Ignored foo{"foo", this};
  Ignored bar{"bar", this};
  Ignored baz{"baz", this};
};
// Tests that ignored attributes are not mandatory.
TEST(IgnoredAttributesTest, NotMandatory) {
  // An empty mapping must succeed: ignored attributes need not be specified.
  const Attributes::Mapping mapping;
  IgnoredAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
}
// Tests that attribute parsing consumes ignored names.  The values are
// arbitrary and unparseable as any type, which is fine since they're dropped.
TEST(IgnoredAttributesTest, IgnoredName) {
  const Attributes::Mapping mapping =
      MakeMapping({{"foo", "blah"}, {"bar", "123"}, {"baz", " "}});
  IgnoredAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
}
// Tests that attribute parsing still fails for unknown names, even when all
// other names in the mapping are ignored attributes.
TEST(IgnoredAttributesTest, UnknownName) {
  const Attributes::Mapping mapping = MakeMapping(
      {{"foo", "blah"}, {"bar", "123"}, {"baz", " "}, {"unknown", ""}});
  IgnoredAttributes attributes;
  EXPECT_THAT(attributes.Reset(mapping),
              test::IsErrorWithSubstr("Unknown attribute"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/eigen.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Produces pairwise activations via a biaffine product between source and
// target token activations, as in the Dozat parser.  This is the runtime
// version of the BiaffineDigraphNetwork, but is implemented as a Component
// instead of a NetworkUnit so it can control operand allocation.
class BiaffineDigraphComponent : public Component {
 public:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;
  // Never preferred over another supporting component.
  bool PreferredTo(const Component &other) const override { return false; }
 private:
  // Weights for computing source-target arc potentials.  Dimensions are
  // [source_dim, target_dim]; see Initialize().
  Matrix<float> arc_weights_;
  // Weights for computing source-selection potentials (length source_dim).
  Vector<float> source_weights_;
  // Weights and bias for root-target arc potentials (length target_dim).
  Vector<float> root_weights_;
  float root_bias_ = 0.0;
  // Source and target token activation inputs, resolved from linked features.
  LayerHandle<float> sources_handle_;
  LayerHandle<float> targets_handle_;
  // Directed adjacency matrix output; diagonal holds root-selection scores.
  PairwiseLayerHandle<float> adjacency_handle_;
  // Handles for intermediate computations.
  LocalMatrixHandle<float> target_product_handle_;
};
bool BiaffineDigraphComponent::Supports(
const ComponentSpec &component_spec,
const string &normalized_builder_name) const {
const string network_unit = NetworkUnit::GetClassName(component_spec);
return (normalized_builder_name == "BulkFeatureExtractorComponent" ||
normalized_builder_name == "BiaffineDigraphComponent") &&
network_unit == "BiaffineDigraphNetwork";
}
// Finds the link named |name| in the |component_spec| and points the |handle|
// at the corresponding layer in the |network_state_manager|.  The layer must
// also match the |required_dimension|.  Returns non-OK on error.  Only
// untransformed, single-embedding, non-recurrent links with the trivial
// "input.focus"/identity translation are accepted.
tensorflow::Status FindAndValidateLink(
    const ComponentSpec &component_spec,
    const NetworkStateManager &network_state_manager, const string &name,
    size_t required_dimension, LayerHandle<float> *handle) {
  // Locate the link channel by name.
  const LinkedFeatureChannel *link = nullptr;
  for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
    if (channel.name() == name) {
      link = &channel;
      break;
    }
  }
  if (link == nullptr) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": link '", name, "' does not exist");
  }
  // Appended to every error below so the offending link is identifiable.
  const string error_suffix = tensorflow::strings::StrCat(
      " in link { ", link->ShortDebugString(), " }");
  // embedding_dim of -1 denotes an untransformed link; anything else would
  // apply a learned transformation this component does not implement.
  if (link->embedding_dim() != -1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": transformed links are forbidden",
        error_suffix);
  }
  if (link->size() != 1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": multi-embedding links are forbidden",
        error_suffix);
  }
  // A link back into this component would be recurrent, which is unsupported.
  if (link->source_component() == component_spec.name()) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": recurrent links are forbidden", error_suffix);
  }
  if (link->fml() != "input.focus" || link->source_translator() != "identity") {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": non-trivial link translation is forbidden",
        error_suffix);
  }
  // Resolve the source layer and check its dimension.
  size_t dimension = 0;
  TF_RETURN_IF_ERROR(network_state_manager.LookupLayer(
      link->source_component(), link->source_layer(), &dimension, handle));
  if (dimension != required_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": link '", name, "' has dimension ", dimension,
        " instead of ", required_dimension, error_suffix);
  }
  return tensorflow::Status::OK();
}
// Loads this component's variables from the |variable_store|, validates their
// dimensions against one another, resolves the "sources" and "targets" links,
// and registers the output and scratch layers.  Returns non-OK on any
// mismatch or unsupported spec feature.
tensorflow::Status BiaffineDigraphComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // The arc weights matrix defines the source and target dimensions that all
  // other variables and links must match.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/weights_arc"),
      &arc_weights_));
  const size_t source_dimension = arc_weights_.num_rows();
  const size_t target_dimension = arc_weights_.num_columns();
  // Source-selection weights must span the source dimension.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/weights_source"),
      &source_weights_));
  if (source_weights_.size() != source_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": dimension mismatch between weights_arc [",
        source_dimension, ",", target_dimension, "] and weights_source [",
        source_weights_.size(), "]");
  }
  // Root weights must span the target dimension.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/root_weights"),
      &root_weights_));
  if (root_weights_.size() != target_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": dimension mismatch between weights_arc [",
        source_dimension, ",", target_dimension, "] and root_weights [",
        root_weights_.size(), "]");
  }
  // The root bias is stored as a length-1 vector; unwrap it to a scalar.
  Vector<float> root_bias;
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/root_bias"),
      &root_bias));
  if (root_bias.size() != 1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": root_bias must be a singleton");
  }
  root_bias_ = root_bias[0];
  // This component consumes exactly two linked features and nothing else.
  if (component_spec.fixed_feature_size() != 0) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": fixed features are forbidden");
  }
  if (component_spec.linked_feature_size() != 2) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": two linked features are required");
  }
  TF_RETURN_IF_ERROR(FindAndValidateLink(component_spec, *network_state_manager,
                                         "sources", source_dimension,
                                         &sources_handle_));
  TF_RETURN_IF_ERROR(FindAndValidateLink(component_spec, *network_state_manager,
                                         "targets", target_dimension,
                                         &targets_handle_));
  // Register the pairwise output layer and the scratch matrix used by
  // Evaluate() to hold the intermediate target-weights product.
  TF_RETURN_IF_ERROR(
      network_state_manager->AddLayer("adjacency", 1, &adjacency_handle_));
  TF_RETURN_IF_ERROR(network_state_manager->AddLocal(source_dimension,
                                                     &target_product_handle_));
  return tensorflow::Status::OK();
}
// Computes the directed adjacency matrix: off-diagonal entries combine
// biaffine arc scores with source-selection scores, and the diagonal is
// overwritten with root-selection scores plus the root bias.
tensorflow::Status BiaffineDigraphComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  // Infer the number of steps from the source and target activations.
  EigenMatrixMap<float> sources =
      AsEigenMap(Matrix<float>(network_states.GetLayer(sources_handle_)));
  EigenMatrixMap<float> targets =
      AsEigenMap(Matrix<float>(network_states.GetLayer(targets_handle_)));
  const size_t num_steps = sources.rows();
  if (targets.rows() != num_steps) {
    return tensorflow::errors::InvalidArgument(
        "step count mismatch between sources (", num_steps, ") and targets (",
        targets.rows(), ")");
  }
  // Since this component has a pairwise layer, allocate steps in one shot.
  network_states.AddSteps(num_steps);
  MutableEigenMatrixMap<float> adjacency =
      AsEigenMap(network_states.GetLayer(adjacency_handle_));
  MutableEigenMatrixMap<float> target_product =
      AsEigenMap(network_states.GetLocal(target_product_handle_));
  // First compute the adjacency matrix of combined arc and source scores.
  // Note: .noalias() ensures that the RHS is assigned directly to the LHS;
  // otherwise, Eigen may allocate a temp matrix to hold the result of the
  // matmul on the RHS and then copy that to the LHS.  See
  // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
  target_product.noalias() = targets * AsEigenMap(arc_weights_).transpose();
  // Adding the source weights here folds them into every arc score via the
  // subsequent matmul against |sources|.
  target_product.rowwise() += AsEigenMap(source_weights_);
  adjacency.noalias() = target_product * sources.transpose();
  // Now overwrite the diagonal with root-selection scores.
  // Note: .array() allows the scalar addition of |root_bias_| to broadcast
  // across the diagonal.  See
  // https://eigen.tuxfamily.org/dox/group__TutorialArrayClass.html
  adjacency.diagonal().noalias() =
      AsEigenMap(root_weights_) * targets.transpose();
  adjacency.diagonal().array() += root_bias_;
  return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(BiaffineDigraphComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
// Step count shared by the "sources" and "targets" components in good configs.
constexpr size_t kNumSteps = 33;
// Activation dimensions of the source and target input layers.
constexpr size_t kSourceDim = 44;
constexpr size_t kTargetDim = 55;
// Dimension that deliberately mismatches kTargetDim, for error tests.
constexpr size_t kBadDim = 11;
// Variable fill values; all exactly representable in binary floating point so
// the expected scores in GoodSpec can be computed exactly.
constexpr float kArcWeight = 1.0;
constexpr float kSourceWeight = 2.0;
constexpr float kRootWeight = 4.0;
constexpr float kRootBias = 8.0;
// Fill values for the linked input layers.
constexpr float kSourceValue = -0.5;
constexpr float kTargetValue = 1.5;
// Names of the preceding components and their output layers.
constexpr char kSourcesComponentName[] = "sources";
constexpr char kTargetsComponentName[] = "targets";
constexpr char kSourcesLayerName[] = "sources";
constexpr char kTargetsLayerName[] = "targets";
// Extra layer whose dimension (kBadDim) mismatches the expected link size.
constexpr char kBadDimLayerName[] = "bad";
// Configuration for the Run() method. This makes it easier for tests to
// manipulate breakages.
struct RunConfig {
  // Number of steps in the preceding components.  Tests set these to unequal
  // values to trigger a step-count-mismatch error.
  size_t sources_num_steps = kNumSteps;
  size_t targets_num_steps = kNumSteps;
  // Dimensions of the variables.  Defaults are consistent with the
  // [kSourceDim, kTargetDim] weights_arc matrix; tests override one at a time
  // to trigger dimension-mismatch errors.
  size_t weights_source_dim = kSourceDim;
  size_t root_weights_dim = kTargetDim;
  size_t root_bias_dim = 1;
};
// Test fixture for BiaffineDigraphComponent.  Registers two preceding
// components ("sources" and "targets"), fills their output layers with
// constant values, and evaluates the biaffine component on top of them.
class BiaffineDigraphComponentTest : public NetworkTestBase {
 protected:
  BiaffineDigraphComponentTest() {
    // The component under test asks for the input batch; hand back |input_|.
    EXPECT_CALL(compute_session_, GetInputBatchCache())
        .WillRepeatedly(Return(&input_));
  }

  // Returns a working spec.
  static ComponentSpec MakeGoodSpec() {
    ComponentSpec component_spec;
    component_spec.set_name(kTestComponentName);
    component_spec.mutable_component_builder()->set_registered_name(
        "bulk_component.BulkFeatureExtractorComponentBuilder");
    component_spec.mutable_network_unit()->set_registered_name(
        "biaffine_units.BiaffineDigraphNetwork");
    // Add untransformed identity links named "sources" and "targets".  Each
    // link points at the same-named component and layer.
    for (const string &name : {kSourcesLayerName, kTargetsLayerName}) {
      LinkedFeatureChannel *link = component_spec.add_linked_feature();
      link->set_name(name);
      link->set_embedding_dim(-1);
      link->set_size(1);
      link->set_source_component(name);
      link->set_source_layer(name);
      link->set_source_translator("identity");
      link->set_fml("input.focus");
    }
    return component_spec;
  }

  // Creates a component, initializes it based on the |component_spec|, and
  // evaluates it. On error, returns non-OK.
  tensorflow::Status Run(const ComponentSpec &component_spec,
                         const RunConfig &config = RunConfig()) {
    // Register the preceding components and their output layers.
    AddComponent(kSourcesComponentName);
    AddLayer(kSourcesLayerName, kSourceDim);
    AddComponent(kTargetsComponentName);
    AddLayer(kTargetsLayerName, kTargetDim);
    AddLayer(kBadDimLayerName, kBadDim);
    AddComponent(kTestComponentName);

    // Register model variables, sized per the |config| so individual tests
    // can introduce mismatches.
    AddMatrixVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/weights_arc"),
        kSourceDim, kTargetDim, kArcWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/weights_source"),
        config.weights_source_dim, kSourceWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/root_weights"),
        config.root_weights_dim, kRootWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/root_bias"),
        config.root_bias_dim, kRootBias);

    TF_RETURN_IF_ERROR(
        Component::CreateOrError("BiaffineDigraphComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    // Populate the preceding components' layers with constant values.
    network_states_.Reset(&network_state_manager_);
    StartComponent(config.sources_num_steps);
    FillLayer(kSourcesComponentName, kSourcesLayerName, kSourceValue);
    StartComponent(config.targets_num_steps);
    FillLayer(kTargetsComponentName, kTargetsLayerName, kTargetValue);
    StartComponent(0);  // BiaffineDigraphComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));
    adjacency_ = GetPairwiseLayer(kTestComponentName, "adjacency");
    return tensorflow::Status::OK();
  }

  // Mocked input batch returned by GetInputBatchCache().
  InputBatchCache input_;
  // Component under test.
  std::unique_ptr<Component> component_;
  // Adjacency matrix captured after a successful Run().
  Matrix<float> adjacency_;
};
// Tests that the good spec works properly.
TEST_F(BiaffineDigraphComponentTest, GoodSpec) {
  TF_ASSERT_OK(Run(MakeGoodSpec()));

  // Diagonal cells hold root-selection scores; all other cells hold arc
  // scores.  Both are exact in float given the power-of-two fill values.
  constexpr float kExpectedRootScore =
      kRootWeight * kTargetValue * kTargetDim + kRootBias;
  constexpr float kExpectedArcScore =
      kSourceDim * kSourceValue * kArcWeight * kTargetValue * kTargetDim +
      kSourceWeight * kSourceValue * kSourceDim;

  ASSERT_EQ(adjacency_.num_rows(), kNumSteps);
  ASSERT_EQ(adjacency_.num_columns(), kNumSteps);
  for (size_t target = 0; target < kNumSteps; ++target) {
    for (size_t source = 0; source < kNumSteps; ++source) {
      const float expected =
          target == source ? kExpectedRootScore : kExpectedArcScore;
      ASSERT_EQ(adjacency_.row(target)[source], expected);
    }
  }
}
// Tests the set of supported components.
TEST_F(BiaffineDigraphComponentTest, Supports) {
  ComponentSpec spec = MakeGoodSpec();
  string selected;

  // The good spec selects the biaffine component via its network unit name.
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BiaffineDigraphComponent");
  spec.mutable_network_unit()->set_registered_name("bad");
  EXPECT_THAT(Component::Select(spec, &selected),
              test::IsErrorWithSubstr("Could not find a best spec"));

  // The component can also be selected by naming it as the builder directly.
  spec = MakeGoodSpec();
  spec.mutable_component_builder()->set_registered_name(
      "BiaffineDigraphComponent");
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BiaffineDigraphComponent");
  spec.mutable_component_builder()->set_registered_name("bad");
  EXPECT_THAT(Component::Select(spec, &selected),
              test::IsErrorWithSubstr("Could not find a best spec"));
}
// Tests that fixed features are rejected.
TEST_F(BiaffineDigraphComponentTest, FixedFeatures) {
  ComponentSpec spec = MakeGoodSpec();
  spec.add_fixed_feature();  // biaffine components take only linked inputs
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("fixed features are forbidden"));
}
// Tests that too few linked features are rejected.
TEST_F(BiaffineDigraphComponentTest, TooFewLinkedFeatures) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature()->RemoveLast();  // drop the "targets" link
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("two linked features are required"));
}
// Tests that too many linked features are rejected.
TEST_F(BiaffineDigraphComponentTest, TooManyLinkedFeatures) {
  ComponentSpec spec = MakeGoodSpec();
  spec.add_linked_feature();  // a third, empty link
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("two linked features are required"));
}
// Tests that a spec with no "sources" link is rejected.
TEST_F(BiaffineDigraphComponentTest, MissingSources) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(0)->set_name("bad");  // rename "sources" away
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("link 'sources' does not exist"));
}
// Tests that a spec with no "targets" link is rejected.
TEST_F(BiaffineDigraphComponentTest, MissingTargets) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_name("bad");  // rename "targets" away
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("link 'targets' does not exist"));
}
// Tests that a spec with transformed links is rejected.
TEST_F(BiaffineDigraphComponentTest, TransformedLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_embedding_dim(123);  // not -1
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("transformed links are forbidden"));
}
// Tests that a spec with multi-embedding links is rejected.
TEST_F(BiaffineDigraphComponentTest, MultiEmbeddingLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_size(2);  // more than one embedding
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("multi-embedding links are forbidden"));
}
// Tests that a spec with recurrent links is rejected.
TEST_F(BiaffineDigraphComponentTest, RecurrentLink) {
  ComponentSpec spec = MakeGoodSpec();
  // Point the link back at the component under test.
  spec.mutable_linked_feature(1)->set_source_component(kTestComponentName);
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("recurrent links are forbidden"));
}
// Tests that a spec with improper FML is rejected.
TEST_F(BiaffineDigraphComponentTest, BadFML) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_fml("bad");  // not "input.focus"
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("non-trivial link translation is forbidden"));
}
// Tests that a spec with non-identity links is rejected.
TEST_F(BiaffineDigraphComponentTest, NonIdentityLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_source_translator("bad");
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("non-trivial link translation is forbidden"));
}
// Tests that a link with the wrong dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, WrongLinkDimension) {
  ComponentSpec spec = MakeGoodSpec();
  // Redirect "targets" at the kBadDim-dimensional layer.
  spec.mutable_linked_feature(1)->set_source_layer(kBadDimLayerName);
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("link 'targets' has dimension 11 instead of 55"));
}
// Tests that a mismatched weights_source dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, WeightsSourceDimensionMismatch) {
  RunConfig run_config;
  run_config.weights_source_dim = 999;  // should be kSourceDim (44)
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("dimension mismatch between weights_arc "
                                      "[44,55] and weights_source [999]"));
}
// Tests that a mismatched root_weights dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, RootWeightsDimensionMismatch) {
  RunConfig run_config;
  run_config.root_weights_dim = 999;  // should be kTargetDim (55)
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("dimension mismatch between weights_arc "
                                      "[44,55] and root_weights [999]"));
}
// Tests that a mismatched root_bias dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, RootBiasDimensionMismatch) {
  RunConfig run_config;
  run_config.root_bias_dim = 999;  // should be exactly 1
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("root_bias must be a singleton"));
}
// Tests that a mismatched number of steps is rejected.
TEST_F(BiaffineDigraphComponentTest, StepCountMismatch) {
  RunConfig run_config;
  run_config.targets_num_steps = 999;  // sources still has kNumSteps (33)
  EXPECT_THAT(
      Run(MakeGoodSpec(), run_config),
      test::IsErrorWithSubstr(
          "step count mismatch between sources (33) and targets (999)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit_base.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Network unit that allows us to make calls to NetworkUnitBase and extract
// features. We may want to provide more optimized versions of this class.
class BulkFeatureExtractorNetwork : public NetworkUnitBase {
 public:
  // Returns true if this supports the |component_spec|. Requires:
  // * A deterministic transition system, which can be advanced from the oracle.
  // * No recurrent linked features (i.e. from this system).
  static bool Supports(const ComponentSpec &component_spec);

  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;

  // Advances the |compute_session| through all oracle transitions and extracts
  // fixed and linked embeddings, concatenates them into an input matrix stored
  // in the NetworkStates in the |session_state|, and points the |inputs| at it.
  // Also adds steps to the NetworkStates. On error, returns non-OK.
  tensorflow::Status EvaluateInputs(SessionState *session_state,
                                    ComputeSession *compute_session,
                                    Matrix<float> *inputs) const;

 private:
  // Implements NetworkUnit. Evaluate() is "final" to encourage inlining.
  // This network produces no logits, hence the empty logits name.
  string GetLogitsName() const override { return ""; }
  // Extracts features for one step and copies them into row |step_index| of
  // the concatenated input matrix. On error, returns non-OK.
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const final;

  // Name of the containing component.
  string name_;

  // Handle to the concatenated input matrix (one row per step).
  LocalMatrixHandle<float> inputs_handle_;
};
bool BulkFeatureExtractorNetwork::Supports(
    const ComponentSpec &component_spec) {
  // Bulk extraction replays the oracle, so the transition system must be
  // deterministic.
  if (!TransitionSystemTraits(component_spec).is_deterministic) return false;

  // Recurrent links would need this component's own activations, which are
  // not yet available when all features are extracted up front.
  for (const LinkedFeatureChannel &link : component_spec.linked_feature()) {
    if (link.source_component() == component_spec.name()) return false;
  }
  return true;
}
tensorflow::Status BulkFeatureExtractorNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Capture the name first; it appears in the error message below and is
  // used again by EvaluateInputs().
  name_ = component_spec.name();
  if (!Supports(component_spec)) {
    return tensorflow::errors::InvalidArgument(
        "BulkFeatureExtractorNetwork does not support component '", name_, "'");
  }

  // Always concatenate the fixed and linked embeddings into a single matrix.
  constexpr bool kUseConcatenatedInput = true;
  TF_RETURN_IF_ERROR(InitializeBase(kUseConcatenatedInput, component_spec,
                                    variable_store, network_state_manager,
                                    extension_manager));

  // Reserve the local matrix that will hold the concatenated inputs.
  return network_state_manager->AddLocal(concatenated_input_dim(),
                                         &inputs_handle_);
}
tensorflow::Status BulkFeatureExtractorNetwork::EvaluateInputs(
    SessionState *session_state, ComputeSession *compute_session,
    Matrix<float> *inputs) const {
  // TODO(googleuser): Try the ComputeSession's bulk feature extraction API?
  // Replay the oracle transitions, extracting features at each step.
  size_t step = 0;
  while (!compute_session->IsTerminal(name_)) {
    session_state->network_states.AddStep();
    TF_RETURN_IF_ERROR(Evaluate(step, session_state, compute_session));
    compute_session->AdvanceFromOracle(name_);
    ++step;
  }
  *inputs = session_state->network_states.GetLocal(inputs_handle_);
  return tensorflow::Status::OK();
}
tensorflow::Status BulkFeatureExtractorNetwork::Evaluate(
    size_t step_index, SessionState *session_state,
    ComputeSession *compute_session) const {
  // Extract the concatenated feature vector for the current step.
  Vector<float> input;
  TF_RETURN_IF_ERROR(EvaluateBase(session_state, compute_session, &input));
  MutableMatrix<float> all_inputs =
      session_state->network_states.GetLocal(inputs_handle_);

  // TODO(googleuser): Punch a hole in EvaluateBase so it writes directly to
  // all_inputs.row(step_index).
  //
  // In the future, we could entirely eliminate copying, by providing a variant
  // of LstmCellFunction::RunInputComputation that adds a partial vector of
  // inputs, e.g. instead of RunInputComputation(x), we compute
  //
  //   RunInputComputation(x[0:32]) + RunInputComputation(x[32:64])
  //
  // where perhaps x[0:32] points directly at a fixed word feature vector, and
  // x[32:64] points directly at the previous layer's outputs (as a linked
  // feature).
  MutableVector<float> output = all_inputs.row(step_index);
  DCHECK_EQ(input.size(), output.size());

  // TODO(googleuser): Try memcpy() or a custom vectorized copy.
  // Use size_t for the index to avoid a signed/unsigned comparison against
  // input.size().
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = input[i];
  }
  return tensorflow::Status::OK();
}
// Bulk version of a DynamicComponent---i.e., a component that was originally
// dynamic but can be automatically upgraded to a bulk version.
class BulkDynamicComponent : public Component {
 protected:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  // Extracts all inputs in bulk, then evaluates the network unit on them.
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;
  // Always prefers the bulk upgrade over the component it replaces.
  bool PreferredTo(const Component &other) const override { return true; }

 private:
  // Feature extractor that builds the input activation matrix.
  BulkFeatureExtractorNetwork bulk_feature_extractor_;

  // Network unit for bulk computation.
  std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// In addition to the BulkFeatureExtractorNetwork requirements, the bulk LSTM
// requires no attention (the runtime doesn't support attention yet).
bool BulkDynamicComponent::Supports(
    const ComponentSpec &component_spec,
    const string &normalized_builder_name) const {
  if (!BulkFeatureExtractorNetwork::Supports(component_spec)) return false;
  // Accept both the original dynamic builder and the explicit bulk name.
  if (normalized_builder_name != "DynamicComponent" &&
      normalized_builder_name != "BulkDynamicComponent") {
    return false;
  }
  return component_spec.attention_component().empty();
}
tensorflow::Status BulkDynamicComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Initialize the feature extractor first, so its concatenated input
  // dimension is available for validation below.
  TF_RETURN_IF_ERROR(bulk_feature_extractor_.Initialize(
      component_spec, variable_store, network_state_manager,
      extension_manager));

  // Instantiate and initialize the bulk network unit named in the spec.
  TF_RETURN_IF_ERROR(BulkNetworkUnit::CreateOrError(
      BulkNetworkUnit::GetClassName(component_spec), &bulk_network_unit_));
  TF_RETURN_IF_ERROR(
      bulk_network_unit_->Initialize(component_spec, variable_store,
                                     network_state_manager,
                                     extension_manager));

  // Ensure the network unit accepts the extractor's concatenated inputs.
  return bulk_network_unit_->ValidateInputDimension(
      bulk_feature_extractor_.concatenated_input_dim());
}
tensorflow::Status BulkDynamicComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  // Build the full input activation matrix, then run the bulk network on it.
  Matrix<float> input_matrix;
  TF_RETURN_IF_ERROR(bulk_feature_extractor_.EvaluateInputs(
      session_state, compute_session, &input_matrix));
  return bulk_network_unit_->Evaluate(input_matrix, session_state);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(BulkDynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;
// Number of steps the mocked transition loop runs.
constexpr size_t kNumSteps = 50;
// Fixed-feature embedding dimension, vocabulary size, and fill value.
constexpr size_t kFixedDim = 11;
constexpr size_t kFixedVocabularySize = 123;
constexpr float kFixedValue = 0.5;
// Linked-feature dimension and fill value.
constexpr size_t kLinkedDim = 13;
constexpr float kLinkedValue = 1.25;
// Names of the preceding component, its layer, and the output layer.
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";
constexpr char kOutputsName[] = "outputs";
// The output layer matches the concatenated fixed+linked input width.
constexpr size_t kOutputsDim = kFixedDim + kLinkedDim;
// Adds one to all inputs.
class BulkAddOne : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return network_state_manager->AddLayer(kOutputsName, kOutputsDim,
                                           &outputs_handle_);
  }
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();  // accepts any input dimension
  }
  string GetLogitsName() const override { return ""; }
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    const MutableMatrix<float> outputs =
        session_state->network_states.GetLayer(outputs_handle_);
    // The output layer must have exactly the same shape as the inputs.
    const bool same_shape = outputs.num_rows() == inputs.num_rows() &&
                            outputs.num_columns() == inputs.num_columns();
    if (!same_shape) {
      return tensorflow::errors::InvalidArgument("Dimension mismatch");
    }
    for (size_t r = 0; r < inputs.num_rows(); ++r) {
      for (size_t c = 0; c < inputs.num_columns(); ++c) {
        outputs.row(r)[c] = inputs.row(r)[c] + 1.0;
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  // Handle to the output layer written by Evaluate().
  LayerHandle<float> outputs_handle_;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkAddOne);
// A component that also prefers itself but is triggered on a certain backend.
// This can be used to cause a component selection conflict.
class ImTheBest : public Component {
 public:
  // Implements Component.  Initialize() and Evaluate() are trivial no-ops;
  // only Supports() and PreferredTo() matter for the conflict test.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }
  // Triggered only by the "CauseConflict" backend name.
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override {
    return component_spec.backend().registered_name() == "CauseConflict";
  }
  // Always claims preference, so pairing it with another self-preferring
  // component makes selection fail.
  bool PreferredTo(const Component &other) const override { return true; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest);
// Test fixture for BulkDynamicComponent.  Builds one preceding component plus
// mocked feature extraction, then evaluates the bulk component on top.
class BulkDynamicComponentTest : public NetworkTestBase {
 protected:
  // Returns a spec that the network supports.
  ComponentSpec GetSupportedSpec() {
    ComponentSpec component_spec;
    component_spec.set_name(kTestComponentName);
    // NB: ForbidNonDeterminism flips this to 100 to break support.
    component_spec.set_num_actions(1);
    component_spec.mutable_network_unit()->set_registered_name("AddOne");
    component_spec.mutable_component_builder()->set_registered_name(
        "DynamicComponent");
    // One fixed feature channel...
    FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
    fixed_feature->set_size(1);
    fixed_feature->set_embedding_dim(kFixedDim);
    fixed_feature->set_vocabulary_size(kFixedVocabularySize);
    // ...and one untransformed linked feature channel from the previous
    // component.
    LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
    linked_feature->set_size(1);
    linked_feature->set_embedding_dim(-1);
    linked_feature->set_source_component(kPreviousComponentName);
    linked_feature->set_source_layer(kPreviousLayerName);
    return component_spec;
  }

  // Adds mock call expectations to the |compute_session_| for the transition
  // system traversal and feature extraction.
  void AddComputeSessionMocks() {
    SetupTransitionLoop(kNumSteps);
    EXPECT_CALL(compute_session_, AdvanceFromOracle(_)).Times(kNumSteps);
    // Every step extracts fixed feature ID 0 with weight 1.0 and a link
    // whose translated value is "step_idx: 0".
    EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
        .Times(kNumSteps)
        .WillRepeatedly(Invoke(ExtractFeatures(0, {{0, 1.0}})));
    EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
        .Times(kNumSteps)
        .WillRepeatedly(Invoke(ExtractLinks(0, {"step_idx: 0"})));
    EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
        .Times(kNumSteps)
        .WillRepeatedly(Return(1));
  }

  // Creates a network unit, initializes it based on the |component_spec_text|,
  // and evaluates it. On error, returns non-OK.
  tensorflow::Status Run(const ComponentSpec &component_spec) {
    // Register the preceding component, its layer, and the fixed embeddings.
    AddComponent(kPreviousComponentName);
    AddLayer(kPreviousLayerName, kLinkedDim);
    AddComponent(kTestComponentName);
    AddFixedEmbeddingMatrix(0, kFixedVocabularySize, kFixedDim, kFixedValue);

    TF_RETURN_IF_ERROR(
        Component::CreateOrError("BulkDynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    // Allocates network states for a few steps.
    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    FillLayer(kPreviousComponentName, kPreviousLayerName, kLinkedValue);
    StartComponent(0);  // the bulk component adds its own steps
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));
    outputs_ = GetLayer(kTestComponentName, kOutputsName);
    return tensorflow::Status::OK();
  }

  // Component under test.
  std::unique_ptr<Component> component_;
  // Output layer captured after a successful Run().
  Matrix<float> outputs_;
};
// Tests that the supported spec is supported.
TEST_F(BulkDynamicComponentTest, Supported) {
  const ComponentSpec spec = GetSupportedSpec();
  string selected;
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BulkDynamicComponent");

  AddComputeSessionMocks();
  TF_ASSERT_OK(Run(spec));

  // Every cell is its input plus one; fixed columns precede linked columns.
  ASSERT_EQ(outputs_.num_rows(), kNumSteps);
  ASSERT_EQ(outputs_.num_columns(), kFixedDim + kLinkedDim);
  for (size_t row = 0; row < kNumSteps; ++row) {
    for (size_t col = 0; col < kFixedDim; ++col) {
      EXPECT_EQ(outputs_.row(row)[col], kFixedValue + 1.0);
    }
    for (size_t col = kFixedDim; col < kFixedDim + kLinkedDim; ++col) {
      EXPECT_EQ(outputs_.row(row)[col], kLinkedValue + 1.0);
    }
  }
}
// Tests that the BulkDynamicComponent also supports its own name.
TEST_F(BulkDynamicComponentTest, SupportsBulkName) {
  ComponentSpec spec = GetSupportedSpec();
  spec.mutable_component_builder()->set_registered_name(
      "BulkDynamicComponent");
  string selected;
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BulkDynamicComponent");
}
// Tests that the transition system must be deterministic.
TEST_F(BulkDynamicComponentTest, ForbidNonDeterminism) {
  ComponentSpec spec = GetSupportedSpec();
  spec.set_num_actions(100);  // multiple actions; no longer deterministic
  string selected;
  EXPECT_THAT(
      Component::Select(spec, &selected),
      test::IsErrorWithSubstr("Could not find a best spec for component"));
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr(
                  "BulkFeatureExtractorNetwork does not support component"));
}
// Tests that links cannot be recurrent.
TEST_F(BulkDynamicComponentTest, ForbidRecurrences) {
  ComponentSpec spec = GetSupportedSpec();
  // Point the link back at the component under test.
  spec.mutable_linked_feature(0)->set_source_component(kTestComponentName);
  string selected;
  EXPECT_THAT(
      Component::Select(spec, &selected),
      test::IsErrorWithSubstr("Could not find a best spec for component"));
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr(
                  "BulkFeatureExtractorNetwork does not support component"));
}
// Tests that the component prefers itself.
TEST_F(BulkDynamicComponentTest, PrefersItself) {
  ComponentSpec spec = GetSupportedSpec();
  // The "CauseConflict" backend triggers the ImTheBest component, which also
  // prefers itself and leads to a selection conflict.
  spec.mutable_backend()->set_registered_name("CauseConflict");
  string selected;
  EXPECT_THAT(Component::Select(spec, &selected),
              test::IsErrorWithSubstr("both think they should be preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/feed_forward_network_kernel.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// A network unit that evaluates a feed-forward multi-layer perceptron.
class BulkFeedForwardNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status ValidateInputDimension(size_t dimension) const override;
  // The logits layer name comes from the underlying kernel.
  string GetLogitsName() const override { return kernel_.logits_name(); }
  // Applies each kernel layer in sequence to the |inputs|.
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override;

 private:
  // Kernel that implements the feed-forward network.
  FeedForwardNetworkKernel kernel_;
};
tensorflow::Status BulkFeedForwardNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Bulk evaluation computes all steps at once, so no activation from this
  // component can feed back into its own inputs.
  for (const LinkedFeatureChannel &link : component_spec.linked_feature()) {
    if (link.source_component() == component_spec.name()) {
      return tensorflow::errors::InvalidArgument(
          "BulkFeedForwardNetwork forbids recurrent links");
    }
  }

  // Delegate layer setup to the shared feed-forward kernel.
  return kernel_.Initialize(component_spec, variable_store,
                            network_state_manager);
}
// Forwards input-dimension validation to the feed-forward kernel.
tensorflow::Status BulkFeedForwardNetwork::ValidateInputDimension(
    size_t dimension) const {
  return kernel_.ValidateInputDimension(dimension);
}
tensorflow::Status BulkFeedForwardNetwork::Evaluate(
    Matrix<float> inputs, SessionState *session_state) const {
  // Feed the matrix through each layer in order; each layer's output becomes
  // the next layer's input.
  for (const FeedForwardNetworkLayer &ffn_layer : kernel_.layers()) {
    inputs = ffn_layer.Apply(inputs, session_state->network_states);
  }
  return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkFeedForwardNetwork);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Fixture dimensions: each input row holds kInputDim features, the network
// emits kLogitsDim logits, and bulk evaluation covers kNumSteps steps.
// Every element of the test input matrix is filled with kEmbedding.
constexpr size_t kInputDim = 5;
constexpr size_t kLogitsDim = 3;
constexpr size_t kNumSteps = 4;
constexpr float kEmbedding = 1.25;
// Rectified linear unit: passes positive values through and clamps
// everything else (including NaN, matching std::max(0.0f, value)) to zero.
float Relu(float value) { return value > 0.0f ? value : 0.0f; }
class BulkFeedForwardNetworkTest : public NetworkTestBase {
protected:
// Adds a weight matrix with the |name_suffix| with the given dimensions and
// |fill_value|.
void AddWeights(const string &name_suffix, size_t num_rows,
size_t num_columns, float fill_value) {
const string weights_name =
tensorflow::strings::StrCat(kTestComponentName, "/weights_",
name_suffix, FlexibleMatrixKernel::kSuffix);
AddMatrixVariable(weights_name, num_columns, num_rows, fill_value);
}
// Adds a bias vector with the |name_suffix| with the given dimensions and
// |fill_value|.
void AddBiases(const string &name_suffix, size_t dimension,
float fill_value) {
const string biases_name =
tensorflow::strings::StrCat(kTestComponentName, "/bias_", name_suffix);
AddVectorVariable(biases_name, dimension, fill_value);
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow::Status Run(const string &component_spec_text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponent(kTestComponentName);
TF_CHECK_OK(BulkNetworkUnit::CreateOrError("BulkFeedForwardNetwork",
&bulk_network_unit_));
TF_RETURN_IF_ERROR(bulk_network_unit_->Initialize(
component_spec, &variable_store_, &network_state_manager_,
&extension_manager_));
size_t input_dimension = 0;
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
input_dimension += channel.embedding_dim();
}
TF_RETURN_IF_ERROR(
bulk_network_unit_->ValidateInputDimension(input_dimension));
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
session_state_.extensions.Reset(&extension_manager_);
const std::vector<float> row(kInputDim, kEmbedding);
UniqueMatrix<float> input(std::vector<std::vector<float>>(kNumSteps, row));
return bulk_network_unit_->Evaluate(Matrix<float>(*input), &session_state_);
}
// Returns the layer named |layer_name| in the current component.
Matrix<float> GetActivations(const string &layer_name) const {
return Matrix<float>(GetLayer(kTestComponentName, layer_name));
}
std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// Tests that BulkFeedForwardNetwork fails when a weight matrix does not match
// the dimension of its output activations.
TEST_F(BulkFeedForwardNetworkTest, BadWeightRows) {
  const string kMismatchedSpec = R"(fixed_feature {
                                      vocabulary_size: 50
                                      embedding_dim: 5
                                      size: 1
                                    }
                                    num_actions: 3)";
  // The softmax should produce kLogitsDim outputs, but its weight matrix is
  // deliberately one output short.
  AddWeights("softmax", kInputDim, kLogitsDim - 1 /* bad */, 1.0);
  AddBiases("softmax", kLogitsDim, 1.0);
  EXPECT_THAT(
      Run(kMismatchedSpec),
      test::IsErrorWithSubstr(
          "Weight matrix shape should be output dimension plus padding"));
}
// Tests that BulkFeedForwardNetwork fails when a weight matrix does not match
// the dimension of its input activations.
TEST_F(BulkFeedForwardNetworkTest, BadWeightColumns) {
  const string kMismatchedSpec = R"(fixed_feature {
                                      vocabulary_size: 50
                                      embedding_dim: 5
                                      size: 1
                                    }
                                    num_actions: 3)";
  // The spec implies kInputDim inputs, but the weight matrix is deliberately
  // built for one extra input.
  AddWeights("softmax", kInputDim + 1 /* bad */, kLogitsDim, 1.0);
  AddBiases("softmax", kLogitsDim, 1.0);
  EXPECT_THAT(Run(kMismatchedSpec),
              test::IsErrorWithSubstr(
                  "Weight matrix shape does not match input dimension"));
}
// Tests that BulkFeedForwardNetwork fails when a bias vector does not match the
// dimension of its output activations.
TEST_F(BulkFeedForwardNetworkTest, BadBiasDimension) {
  const string kMismatchedSpec = R"(fixed_feature {
                                      vocabulary_size: 50
                                      embedding_dim: 5
                                      size: 1
                                    }
                                    num_actions: 3)";
  // The weights are consistent; only the bias vector is deliberately one
  // element too long.
  AddWeights("softmax", kInputDim, kLogitsDim, 1.0);
  AddBiases("softmax", kLogitsDim + 1 /* bad */, 1.0);
  EXPECT_THAT(Run(kMismatchedSpec),
              test::IsErrorWithSubstr(
                  "Bias vector shape does not match output dimension"));
}
// Tests that BulkFeedForwardNetwork fails when the value of the
// "layer_norm_input" option is not false.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedLayerNormInputOption) {
  const string kLayerNormSpec = R"(network_unit {
                                     parameters {
                                       key: 'layer_norm_input'
                                       value: 'true'
                                     }
                                   })";
  EXPECT_THAT(Run(kLayerNormSpec),
              test::IsErrorWithSubstr("Layer norm is not supported"));
}
// Tests that BulkFeedForwardNetwork fails when the value of the
// "layer_norm_hidden" option is not false.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedLayerNormHiddenOption) {
  const string kLayerNormSpec = R"(network_unit {
                                     parameters {
                                       key: 'layer_norm_hidden'
                                       value: 'true'
                                     }
                                   })";
  EXPECT_THAT(Run(kLayerNormSpec),
              test::IsErrorWithSubstr("Layer norm is not supported"));
}
// Tests that BulkFeedForwardNetwork fails when the value of the "nonlinearity"
// option is not "relu".
TEST_F(BulkFeedForwardNetworkTest, UnsupportedNonlinearityOption) {
  const string kEluSpec = R"(network_unit {
                               parameters {
                                 key: 'nonlinearity'
                                 value: 'elu'
                               }
                             })";
  EXPECT_THAT(Run(kEluSpec),
              test::IsErrorWithSubstr("Non-linearity is not supported"));
}
// Tests that BulkFeedForwardNetwork fails if there is a recurrent link.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedRecurrentLink) {
  // A linked feature whose source is the component itself is recurrent,
  // which bulk evaluation cannot support.
  const string kRecurrentSpec = R"(linked_feature {
                                     source_component: 'test_component'
                                   })";
  EXPECT_THAT(Run(kRecurrentSpec),
              test::IsErrorWithSubstr(
                  "BulkFeedForwardNetwork forbids recurrent links"));
}
// Tests that the BulkFeedForwardNetwork works when there are no hidden layers,
// just a softmax that computes logits.
TEST_F(BulkFeedForwardNetworkTest, JustLogits) {
  const string kLogitsOnlySpec = R"(fixed_feature {
                                      vocabulary_size: 50
                                      embedding_dim: 5
                                      size: 1
                                    }
                                    num_actions: 3)";
  const float kWeightValue = 1.5;
  const float kBiasValue = 0.75;
  AddWeights("softmax", kInputDim, kLogitsDim, kWeightValue);
  AddBiases("softmax", kLogitsDim, kBiasValue);
  TF_ASSERT_OK(Run(kLogitsOnlySpec));
  EXPECT_EQ("logits", bulk_network_unit_->GetLogitsName());
  // With no hidden layers, each logit is the weighted sum of kInputDim
  // embedding values plus the bias, with no non-linearity applied.
  ExpectMatrix(GetActivations("logits"), kNumSteps, kLogitsDim,
               kInputDim * kEmbedding * kWeightValue + kBiasValue);
}
// Tests that the BulkFeedForwardNetwork works with multiple hidden layers as
// well as a softmax that computes logits.
TEST_F(BulkFeedForwardNetworkTest, MultiLayer) {
  const size_t kDims[] = {kInputDim, 4, 3, 2};
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          network_unit {
                            parameters {
                              key: 'hidden_layer_sizes'
                              value: '4,3'
                            }
                          }
                          num_actions: 2)";
  const float kWeights[] = {-1.5, 1.0, 0.5};
  const float kBiases[] = {0.75, -0.5, -1.0};
  AddWeights("0", kDims[0], kDims[1], kWeights[0]);
  AddBiases("0", kDims[1], kBiases[0]);
  AddWeights("1", kDims[1], kDims[2], kWeights[1]);
  AddBiases("1", kDims[2], kBiases[1]);
  AddWeights("softmax", kDims[2], kDims[3], kWeights[2]);
  AddBiases("softmax", kDims[3], kBiases[2]);
  TF_ASSERT_OK(Run(kSpec));
  EXPECT_EQ("logits", bulk_network_unit_->GetLogitsName());
  // Each input element is kEmbedding, so the first hidden layer's weighted
  // sum is kDims[0] * kEmbedding * kWeights[0].  (The kEmbedding factor was
  // previously omitted here; the omission was masked because the
  // pre-activation is negative either way and ReLU clamps it to zero.)
  float expected = Relu(kDims[0] * kEmbedding * kWeights[0] + kBiases[0]);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], expected);
  expected = Relu(kDims[1] * expected * kWeights[1] + kBiases[1]);
  ExpectMatrix(GetActivations("layer_1"), kNumSteps, kDims[2], expected);
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[2], expected);
  // The softmax produces raw logits: no non-linearity on the final layer.
  expected = kDims[2] * expected * kWeights[2] + kBiases[2];
  ExpectMatrix(GetActivations("logits"), kNumSteps, kDims[3], expected);
}
// Tests that the BulkFeedForwardNetwork does not produce logits and does not
// use the softmax variables when the component is deterministic.
TEST_F(BulkFeedForwardNetworkTest, NoLogitsOrSoftmaxWhenDeterministic) {
  const size_t kDims[] = {kInputDim, 4};
  // num_actions of 1 presumably marks the component deterministic; the
  // expectations below confirm that no logits layer is created for it.
  const string kSpec = R"(num_actions: 1
                          fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          network_unit {
                            parameters {
                              key: 'hidden_layer_sizes'
                              value: '4'
                            }
                          })";
  const float kWeight = -1.5;
  const float kBias = 0.75;
  // No "softmax" weights or biases are registered; initialization must
  // still succeed because the softmax is skipped.
  AddWeights("0", kDims[0], kDims[1], kWeight);
  AddBiases("0", kDims[1], kBias);
  TF_ASSERT_OK(Run(kSpec));
  // No specified logits layer.
  EXPECT_TRUE(bulk_network_unit_->GetLogitsName().empty());
  // No "logits" layer was registered in the network state manager either.
  size_t unused_dimension = 0;
  LayerHandle<float> unused_handle;
  EXPECT_THAT(
      network_state_manager_.LookupLayer(kTestComponentName, "logits",
                                         &unused_dimension, &unused_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'logits' in component 'test_component'"));
  // Hidden layer is still produced.  The pre-activation is negative
  // (5 * 1.25 * -1.5 + 0.75 = -8.625), so ReLU clamps the expectation to 0.
  const float kExpected = Relu(kDims[0] * kEmbedding * kWeight + kBias);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], kExpected);
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[1], kExpected);
}
// Tests that the BulkFeedForwardNetwork does not produce logits when
// omit_logits is true, even if there are actions.
TEST_F(BulkFeedForwardNetworkTest, NoLogitsOrSoftmaxWhenOmitLogitsTrue) {
  const size_t kDims[] = {kInputDim, 4};
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          network_unit {
                            parameters {
                              key: 'hidden_layer_sizes'
                              value: '4'
                            }
                            parameters {
                              key: 'omit_logits'
                              value: 'true'
                            }
                          }
                          num_actions: 10)";
  const float kWeight = 1.5;
  const float kBias = 0.75;
  // No "softmax" weights or biases.
  AddWeights("0", kDims[0], kDims[1], kWeight);
  AddBiases("0", kDims[1], kBias);
  TF_ASSERT_OK(Run(kSpec));
  // No specified logits layer.
  EXPECT_TRUE(bulk_network_unit_->GetLogitsName().empty());
  // No "logits" layer.
  size_t unused_dimension = 0;
  LayerHandle<float> unused_handle;
  EXPECT_THAT(
      network_state_manager_.LookupLayer(kTestComponentName, "logits",
                                         &unused_dimension, &unused_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'logits' in component 'test_component'"));
  // Hidden layer is still produced.  The expectation is wrapped in Relu()
  // for consistency with the hidden-layer non-linearity, as in the
  // deterministic-component test; the pre-activation is positive here, so
  // the expected value is unchanged.
  const float kExpected = Relu(kDims[0] * kEmbedding * kWeight + kBias);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], kExpected);
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[1], kExpected);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment