Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4bb31d0
Commit
a4bb31d0
authored
May 02, 2018
by
Terry Koo
Browse files
Export @195097388.
parent
dea7ecf6
Changes
294
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3241 additions
and
0 deletions
+3241
-0
research/syntaxnet/dragnn/runtime/network_unit.cc
research/syntaxnet/dragnn/runtime/network_unit.cc
+43
-0
research/syntaxnet/dragnn/runtime/network_unit.h
research/syntaxnet/dragnn/runtime/network_unit.h
+95
-0
research/syntaxnet/dragnn/runtime/network_unit_base.cc
research/syntaxnet/dragnn/runtime/network_unit_base.cc
+171
-0
research/syntaxnet/dragnn/runtime/network_unit_base.h
research/syntaxnet/dragnn/runtime/network_unit_base.h
+137
-0
research/syntaxnet/dragnn/runtime/network_unit_base_test.cc
research/syntaxnet/dragnn/runtime/network_unit_base_test.cc
+403
-0
research/syntaxnet/dragnn/runtime/network_unit_test.cc
research/syntaxnet/dragnn/runtime/network_unit_test.cc
+82
-0
research/syntaxnet/dragnn/runtime/operands.cc
research/syntaxnet/dragnn/runtime/operands.cc
+142
-0
research/syntaxnet/dragnn/runtime/operands.h
research/syntaxnet/dragnn/runtime/operands.h
+236
-0
research/syntaxnet/dragnn/runtime/operands_test.cc
research/syntaxnet/dragnn/runtime/operands_test.cc
+350
-0
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers.cc
...ch/syntaxnet/dragnn/runtime/recurrent_sequence_linkers.cc
+96
-0
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers_test.cc
...ntaxnet/dragnn/runtime/recurrent_sequence_linkers_test.cc
+151
-0
research/syntaxnet/dragnn/runtime/reversed_sequence_linker.cc
...arch/syntaxnet/dragnn/runtime/reversed_sequence_linker.cc
+76
-0
research/syntaxnet/dragnn/runtime/reversed_sequence_linker_test.cc
...syntaxnet/dragnn/runtime/reversed_sequence_linker_test.cc
+129
-0
research/syntaxnet/dragnn/runtime/select_best_component_transformer.cc
...axnet/dragnn/runtime/select_best_component_transformer.cc
+58
-0
research/syntaxnet/dragnn/runtime/select_best_component_transformer_test.cc
.../dragnn/runtime/select_best_component_transformer_test.cc
+118
-0
research/syntaxnet/dragnn/runtime/sequence_backend.cc
research/syntaxnet/dragnn/runtime/sequence_backend.cc
+152
-0
research/syntaxnet/dragnn/runtime/sequence_backend.h
research/syntaxnet/dragnn/runtime/sequence_backend.h
+124
-0
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
+172
-0
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
...ntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
+195
-0
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
...et/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
+311
-0
No files found.
Too many changes to show.
To preserve performance only
294 of 294+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/network_unit.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <vector>
#include "tensorflow/core/lib/strings/str_util.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
string
NetworkUnit
::
GetClassName
(
const
ComponentSpec
&
component_spec
)
{
// The Python registration API is based on (relative) module paths, such as
// "some.module.FooNetwork". Therefore, we discard the module path prefix and
// use only the final segment, which is the subclass name.
const
std
::
vector
<
string
>
segments
=
tensorflow
::
str_util
::
Split
(
component_spec
.
network_unit
().
registered_name
(),
"."
);
CHECK_GT
(
segments
.
size
(),
0
)
<<
"No network unit name for component spec: "
<<
component_spec
.
ShortDebugString
();
return
segments
.
back
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Network Unit"
,
dragnn
::
runtime
::
NetworkUnit
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_H_
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for network units for sequential inference.
class
NetworkUnit
:
public
RegisterableClass
<
NetworkUnit
>
{
public:
NetworkUnit
(
const
NetworkUnit
&
that
)
=
delete
;
NetworkUnit
&
operator
=
(
const
NetworkUnit
&
that
)
=
delete
;
virtual
~
NetworkUnit
()
=
default
;
// Returns the network unit class name specified in the |component_spec|.
static
string
GetClassName
(
const
ComponentSpec
&
component_spec
);
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
=
0
;
// Returns the name of the layer that contains classification logits, or an
// empty string if this does not produce logits. Requires that Initialize()
// was called.
virtual
string
GetLogitsName
()
const
=
0
;
// Evaluates this network unit on the |session_state| and |compute_session|.
// Requires that:
// * The network states in the |session_state| is positioned at the current
// component, which must have at least |step_index|+1 steps.
// * The same component in the |compute_session| must have traversed
// |step_index| transitions.
// * Initialize() was called.
// On error, returns non-OK.
virtual
tensorflow
::
Status
Evaluate
(
size_t
step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
=
0
;
protected:
NetworkUnit
()
=
default
;
private:
// Helps prevent use of the Create() method; use CreateOrError() instead.
using
RegisterableClass
<
NetworkUnit
>::
Create
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Network Unit"
,
dragnn
::
runtime
::
NetworkUnit
);
}
// namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::NetworkUnit, #subclass, subclass)
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_H_
research/syntaxnet/dragnn/runtime/network_unit_base.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <string.h>
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns the sum of the dimensions of all channels in the |manager|. The
// EmbeddingManager template type should be either FixedEmbeddingManager or
// LinkedEmbeddingManager; note that both share the same API.
template
<
class
EmbeddingManager
>
size_t
SumEmbeddingDimensions
(
const
EmbeddingManager
&
manager
)
{
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
manager
.
num_channels
();
++
i
)
{
sum
+=
manager
.
embedding_dim
(
i
);
}
return
sum
;
}
// Copies each channel of the |embeddings| into the region starting at |data|.
// Returns a pointer to one past the last element of the copied region. The
// Embeddings type should be FixedEmbeddings or LinkedEmbeddings; note that both
// have the same API.
//
// TODO(googleuser): Try a vectorized copy instead of memcpy(). Unclear whether
// we can do better, though. For one, the memcpy() implementation may already
// be vectorized. Also, while the input embeddings are aligned, the output is
// not; e.g., consider concatenating inputs with dims 7 and 9. This could be
// addressed by requiring that embedding dims are aligned, or by handling the
// unaligned prefix separately.
//
// TODO(googleuser): Consider alternatives for handling fixed feature channels
// with size>1. The least surprising approach is to concatenate the size>1
// embeddings inside FixedEmbeddings, so the channel IDs still correspond to
// positions in the ComponentSpec.fixed_feature list. However, that means the
// same embedding gets copied twice, once there and once here. Conversely, we
// could split the size>1 embeddings into separate channels, eliding a copy
// while obfuscating the channel IDs. IMO, separate channels seem better
// because very few bits of DRAGNN actually access individual channels, and I
// wrote many of those bits.
template
<
class
Embeddings
>
float
*
CopyEmbeddings
(
const
Embeddings
&
embeddings
,
float
*
data
)
{
for
(
size_t
i
=
0
;
i
<
embeddings
.
num_embeddings
();
++
i
)
{
const
Vector
<
float
>
vector
=
embeddings
.
embedding
(
i
);
memcpy
(
data
,
vector
.
data
(),
vector
.
size
()
*
sizeof
(
float
));
data
+=
vector
.
size
();
}
return
data
;
}
}
// namespace
tensorflow
::
Status
NetworkUnitBase
::
InitializeBase
(
bool
use_concatenated_input
,
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
use_concatenated_input_
=
use_concatenated_input
;
num_actions_
=
component_spec
.
num_actions
();
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
concatenated_input_dim_
=
SumEmbeddingDimensions
(
fixed_embedding_manager_
)
+
SumEmbeddingDimensions
(
linked_embedding_manager_
);
if
(
use_concatenated_input_
)
{
// If there is <= 1 input embedding, then the concatenation is trivial and
// we don't need a local vector; see ConcatenateInput().
const
size_t
num_embeddings
=
fixed_embedding_manager_
.
num_embeddings
()
+
linked_embedding_manager_
.
num_embeddings
();
if
(
num_embeddings
>
1
)
{
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLocal
(
concatenated_input_dim_
,
&
concatenated_input_handle_
));
}
// Check that all fixed features are embedded.
for
(
size_t
i
=
0
;
i
<
fixed_embedding_manager_
.
num_channels
();
++
i
)
{
if
(
!
fixed_embedding_manager_
.
is_embedded
(
i
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-embedded fixed features cannot be concatenated"
);
}
}
}
extension_manager
->
GetShared
(
&
fixed_embeddings_handle_
);
extension_manager
->
GetShared
(
&
linked_embeddings_handle_
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
NetworkUnitBase
::
EvaluateBase
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
Vector
<
float
>
*
concatenated_input
)
const
{
FixedEmbeddings
&
fixed_embeddings
=
session_state
->
extensions
.
Get
(
fixed_embeddings_handle_
);
LinkedEmbeddings
&
linked_embeddings
=
session_state
->
extensions
.
Get
(
linked_embeddings_handle_
);
TF_RETURN_IF_ERROR
(
fixed_embeddings
.
Reset
(
&
fixed_embedding_manager_
,
session_state
->
network_states
,
compute_session
));
TF_RETURN_IF_ERROR
(
linked_embeddings
.
Reset
(
&
linked_embedding_manager_
,
session_state
->
network_states
,
compute_session
));
if
(
use_concatenated_input_
&&
concatenated_input
!=
nullptr
)
{
*
concatenated_input
=
ConcatenateInput
(
session_state
);
}
return
tensorflow
::
Status
::
OK
();
}
Vector
<
float
>
NetworkUnitBase
::
ConcatenateInput
(
SessionState
*
session_state
)
const
{
DCHECK
(
use_concatenated_input_
);
const
FixedEmbeddings
&
fixed_embeddings
=
session_state
->
extensions
.
Get
(
fixed_embeddings_handle_
);
const
LinkedEmbeddings
&
linked_embeddings
=
session_state
->
extensions
.
Get
(
linked_embeddings_handle_
);
const
size_t
num_embeddings
=
fixed_embeddings
.
num_embeddings
()
+
linked_embeddings
.
num_embeddings
();
// Special cases where no actual concatenation is required.
if
(
num_embeddings
==
0
)
return
{};
if
(
num_embeddings
==
1
)
{
return
fixed_embeddings
.
num_embeddings
()
>
0
?
fixed_embeddings
.
embedding
(
0
)
:
linked_embeddings
.
embedding
(
0
);
}
// General case; concatenate into a local vector. The ordering of embeddings
// must be exactly the same as in the Python codebase, which is:
// 1. Fixed embeddings before linked embeddings (see get_input_tensor() in
// network_units.py).
// 2. In each type, ordered as listed in ComponentSpec.fixed/linked_feature
// (see DynamicComponentBuilder._feedforward_unit() in component.py).
//
// Since FixedEmbeddings and LinkedEmbeddings already follow the order defined
// in the ComponentSpec, it suffices to append each fixed embedding, then each
// linked embedding.
const
MutableVector
<
float
>
concatenation
=
session_state
->
network_states
.
GetLocal
(
concatenated_input_handle_
);
float
*
data
=
concatenation
.
data
();
data
=
CopyEmbeddings
(
fixed_embeddings
,
data
);
data
=
CopyEmbeddings
(
linked_embeddings
,
data
);
DCHECK_EQ
(
data
,
concatenation
.
end
());
return
Vector
<
float
>
(
concatenation
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit_base.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#include <stddef.h>
#include <utility>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A base class for network units that provides common functionality, analogous
// to NetworkUnitInterface.__init__() in network_units.py. Specifically, this
// class manages and builds input embeddings and, as an convenience, optionally
// concatenates the input embeddings into a single vector.
//
// Since recurrent layers are both outputs and inputs, they complicate network
// unit initialization. In particular, the linked embeddings cannot be set up
// until the charateristics of all recurrently-accessible layers are known. On
// the other hand, some layers cannot be initialized until all inputs, including
// the linked embeddings, are set up. For example, the IdentityNetwork outputs
// a layer whose dimension is the sum of all input dimensions.
//
// To accommodate recurrent layers, network unit initialization is organized
// into three phases:
// 1. (Subclass) Initialize all recurrently-accessible layers.
// 2. (This class) Initialize embedding managers and other common state.
// 3. (Subclass) Initialize any non-recurrent layers.
//
// Concretely, the subclass's Initialize() should first add recurrent layers,
// then call InitializeBase(), and finally finish initializing. Evaluation is
// simpler: the subclass's Evaluate() may call EvaluateBase() at any time.
//
// Note: Network unit initialization is similarly interleaved between base and
// subclasses in the Python codebase; see NetworkUnitInterface.get_layer_size()
// and the "init_layers" argument to NetworkUnitInterface.__init__().
class
NetworkUnitBase
:
public
NetworkUnit
{
public:
// Initializes common state as configured in the |component_spec|. Retrieves
// pre-trained embedding matrices from the |variable_store|. Looks up linked
// embeddings in the |network_state_manager|, which must contain all recurrent
// layers. Requests any required extensions from the |extension_manager|. If
// |use_concatenated_input| is true, prepares to concatenate input embeddings
// in EvaluateBase(). On error, returns non-OK.
tensorflow
::
Status
InitializeBase
(
bool
use_concatenated_input
,
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
);
// Resets the fixed and linked embeddings in the |session_state| using its
// network states and the |compute_session|. Requires that InitializeBase()
// was called. If this was prepared for concatenation (see InitializeBase())
// and if |concatenated_input| is non-null, points it at the concatenation of
// the fixed and linked embeddings. Otherwise, no concatenation occurs. On
// error, returns non-OK.
tensorflow
::
Status
EvaluateBase
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
Vector
<
float
>
*
concatenated_input
)
const
;
// Accessors. All require that InitializeBase() was called.
const
FixedEmbeddingManager
&
fixed_embedding_manager
()
const
;
const
LinkedEmbeddingManager
&
linked_embedding_manager
()
const
;
size_t
num_actions
()
const
{
return
num_actions_
;
}
size_t
concatenated_input_dim
()
const
{
return
concatenated_input_dim_
;
}
private:
// Returns the concatenation of the fixed and linked embeddings in the
// |seesion_state|. Requires that |use_concatenated_input_| is true.
Vector
<
float
>
ConcatenateInput
(
SessionState
*
session_state
)
const
;
// Managers for fixed and linked embeddings in this component.
FixedEmbeddingManager
fixed_embedding_manager_
;
LinkedEmbeddingManager
linked_embedding_manager_
;
// Fixed and linked embeddings.
SharedExtensionHandle
<
FixedEmbeddings
>
fixed_embeddings_handle_
;
SharedExtensionHandle
<
LinkedEmbeddings
>
linked_embeddings_handle_
;
// Number of actions supported by the transition system.
size_t
num_actions_
=
0
;
// Sum of dimensions of all fixed and linked embeddings.
size_t
concatenated_input_dim_
=
0
;
// Whether to concatenate the input embeddings.
bool
use_concatenated_input_
=
false
;
// Handle of the vector that holds the concatenated input, or invalid if no
// concatenation is required.
LocalVectorHandle
<
float
>
concatenated_input_handle_
;
};
// Implementation details below.
inline
const
FixedEmbeddingManager
&
NetworkUnitBase
::
fixed_embedding_manager
()
const
{
return
fixed_embedding_manager_
;
}
inline
const
LinkedEmbeddingManager
&
NetworkUnitBase
::
linked_embedding_manager
()
const
{
return
linked_embedding_manager_
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
research/syntaxnet/dragnn/runtime/network_unit_base_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
_
;
using
::
testing
::
Invoke
;
using
::
testing
::
Return
;
// Dimensions of the layers in the network.
static
constexpr
size_t
kPreviousDim
=
77
;
static
constexpr
size_t
kRecurrentDim
=
123
;
// Contents of the layers in the network.
static
constexpr
float
kPreviousValue
=
-
2.75
;
static
constexpr
float
kRecurrentValue
=
6.25
;
// Number of steps taken in each component.
static
constexpr
size_t
kNumSteps
=
10
;
// A trivial network unit that exposes the concatenated inputs. Note that
// NetworkUnitBase does not override the interface methods, so we need a
// concrete subclass for testing.
class
FooNetwork
:
public
NetworkUnitBase
{
public:
void
RequestConcatenation
()
{
request_concatenation_
=
true
;
}
void
ProvideConcatenatedInput
()
{
provide_concatenated_input_
=
true
;
}
Vector
<
float
>
concatenated_input
()
const
{
return
concatenated_input_
;
}
// Implements NetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLayer
(
"recurrent_layer"
,
kRecurrentDim
,
&
recurrent_handle_
));
return
InitializeBase
(
request_concatenation_
,
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
);
}
string
GetLogitsName
()
const
override
{
return
""
;
}
tensorflow
::
Status
Evaluate
(
size_t
unused_step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
override
{
return
EvaluateBase
(
session_state
,
compute_session
,
provide_concatenated_input_
?
&
concatenated_input_
:
nullptr
);
}
private:
bool
request_concatenation_
=
false
;
bool
provide_concatenated_input_
=
false
;
LayerHandle
<
float
>
recurrent_handle_
;
mutable
Vector
<
float
>
concatenated_input_
;
// Evaluate() sets this
};
class
NetworkUnitBaseTest
:
public
NetworkTestBase
{
protected:
// Initializes the |network_unit_| based on the |component_spec_text| and
// evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
string
&
component_spec_text
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponent
(
"previous_component"
);
AddLayer
(
"previous_layer"
,
kPreviousDim
);
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
network_unit_
.
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
// Create and populate the network states.
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
StartComponent
(
kNumSteps
);
FillLayer
(
"previous_component"
,
"previous_layer"
,
kPreviousValue
);
FillLayer
(
kTestComponentName
,
"recurrent_layer"
,
kRecurrentValue
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
// Neither FooNetwork nor NetworkUnitBase look at the step index, so use an
// arbitrary value.
return
network_unit_
.
Evaluate
(
0
,
&
session_state_
,
&
compute_session_
);
}
FooNetwork
network_unit_
;
std
::
vector
<
std
::
vector
<
float
>>
concatenated_inputs_
;
};
// Tests that NetworkUnitBase produces an empty vector when concatenating and
// there are no input embeddings.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenateNoInputs
)
{
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
""
));
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
0
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
0
);
EXPECT_TRUE
(
network_unit_
.
concatenated_input
().
empty
());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single fixed channel.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenateOneFixedChannel
)
{
const
float
kEmbedding
=
1.5
;
const
float
kFeature
=
0.5
;
const
size_t
kDim
=
13
;
const
string
kSpec
=
R"(num_actions: 42
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})"
;
AddFixedEmbeddingMatrix
(
0
,
11
,
kDim
,
kEmbedding
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
1
,
kFeature
}})));
const
float
kValue
=
kEmbedding
*
kFeature
;
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
42
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kDim
);
ExpectVector
(
network_unit_
.
concatenated_input
(),
network_unit_
.
concatenated_input_dim
(),
kValue
);
}
// Tests that NetworkUnitBase does not concatenate if concatenation is requested
// and the concatenated input vector is not provided.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenatedInputVectorNotProvided
)
{
const
float
kEmbedding
=
1.5
;
const
float
kFeature
=
0.5
;
const
size_t
kDim
=
13
;
const
string
kSpec
=
R"(num_actions: 37
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})"
;
AddFixedEmbeddingMatrix
(
0
,
11
,
kDim
,
kEmbedding
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
1
,
kFeature
}})));
network_unit_
.
RequestConcatenation
();
TF_ASSERT_OK
(
Run
(
kSpec
));
// Embedding managers and other config is set up properly.
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
37
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kDim
);
// But the concatenation was not performed.
EXPECT_TRUE
(
network_unit_
.
concatenated_input
().
empty
());
}
// As above, but with the converse condition: does not request concatenation,
// but does provide the concatenated input vector.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenationNotRequested
)
{
const
float
kEmbedding
=
1.5
;
const
float
kFeature
=
0.5
;
const
size_t
kDim
=
13
;
const
string
kSpec
=
R"(num_actions: 31
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})"
;
AddFixedEmbeddingMatrix
(
0
,
11
,
kDim
,
kEmbedding
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
1
,
kFeature
}})));
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
kSpec
));
// Embedding managers and other config is set up properly.
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
31
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kDim
);
// But the concatenation was not performed.
EXPECT_TRUE
(
network_unit_
.
concatenated_input
().
empty
());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single linked channel.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenateOneLinkedChannel
)
{
const
string
kSpec
=
R"(num_actions: 37
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})"
;
EXPECT_CALL
(
compute_session_
,
GetTranslatedLinkFeatures
(
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractLinks
(
0
,
{
"step_idx: 5"
})));
EXPECT_CALL
(
compute_session_
,
SourceComponentBeamSize
(
_
,
_
))
.
WillRepeatedly
(
Return
(
1
));
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
0
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
37
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kPreviousDim
);
ExpectVector
(
network_unit_
.
concatenated_input
(),
network_unit_
.
concatenated_input_dim
(),
kPreviousValue
);
}
// Tests that NetworkUnitBase concatenates a fixed and linked channel in that
// order.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenateOneChannelOfEachType
)
{
const
float
kEmbedding
=
1.25
;
const
float
kFeature
=
0.75
;
const
size_t
kFixedDim
=
13
;
const
string
kSpec
=
R"(num_actions: 77
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})"
;
AddFixedEmbeddingMatrix
(
0
,
11
,
kFixedDim
,
kEmbedding
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
1
,
kFeature
}})));
const
float
kFixedValue
=
kEmbedding
*
kFeature
;
EXPECT_CALL
(
compute_session_
,
GetTranslatedLinkFeatures
(
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractLinks
(
0
,
{
"step_idx: 5"
})));
EXPECT_CALL
(
compute_session_
,
SourceComponentBeamSize
(
_
,
_
))
.
WillRepeatedly
(
Return
(
1
));
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
77
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kFixedDim
+
kPreviousDim
);
// Check that each sub-segment is equal to one of the input embeddings.
const
Vector
<
float
>
input
=
network_unit_
.
concatenated_input
();
EXPECT_EQ
(
input
.
size
(),
network_unit_
.
concatenated_input_dim
());
size_t
index
=
0
;
size_t
end
=
kFixedDim
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kFixedValue
);
end
+=
kPreviousDim
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kPreviousValue
);
}
// Tests that NetworkUnitBase produces a properly-ordered concatenation of
// multiple fixed and linked channels, including a recurrent channel.
TEST_F
(
NetworkUnitBaseTest
,
ConcatenateMultipleChannelsOfEachType
)
{
const
float
kEmbedding0
=
1.25
;
const
float
kEmbedding1
=
-
0.125
;
const
float
kFeature0
=
0.75
;
const
float
kFeature1
=
-
2.5
;
const
size_t
kFixedDim0
=
13
;
const
size_t
kFixedDim1
=
19
;
const
string
kSpec
=
R"(num_actions: 99
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
fixed_feature {
vocabulary_size: 17
embedding_dim: 19
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent_layer'
size: 1
})"
;
AddFixedEmbeddingMatrix
(
0
,
11
,
kFixedDim0
,
kEmbedding0
);
AddFixedEmbeddingMatrix
(
1
,
17
,
kFixedDim1
,
kEmbedding1
);
EXPECT_CALL
(
compute_session_
,
GetInputFeatures
(
_
,
_
,
_
,
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
0
,
{{
1
,
kFeature0
}})))
.
WillOnce
(
Invoke
(
ExtractFeatures
(
1
,
{{
1
,
kFeature1
}})));
const
float
kFixedValue0
=
kEmbedding0
*
kFeature0
;
const
float
kFixedValue1
=
kEmbedding1
*
kFeature1
;
EXPECT_CALL
(
compute_session_
,
GetTranslatedLinkFeatures
(
_
,
_
))
.
WillOnce
(
Invoke
(
ExtractLinks
(
0
,
{
"step_idx: 5"
})))
.
WillOnce
(
Invoke
(
ExtractLinks
(
1
,
{
"step_idx: 6"
})));
EXPECT_CALL
(
compute_session_
,
SourceComponentBeamSize
(
_
,
_
))
.
WillRepeatedly
(
Return
(
1
));
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
TF_ASSERT_OK
(
Run
(
kSpec
));
EXPECT_EQ
(
network_unit_
.
fixed_embedding_manager
().
num_channels
(),
2
);
EXPECT_EQ
(
network_unit_
.
linked_embedding_manager
().
num_channels
(),
2
);
EXPECT_EQ
(
network_unit_
.
num_actions
(),
99
);
EXPECT_EQ
(
network_unit_
.
concatenated_input_dim
(),
kFixedDim0
+
kFixedDim1
+
kPreviousDim
+
kRecurrentDim
);
// Check that each sub-segment is equal to one of the input embeddings. For
// compatibility with the Python codebase, fixed channels must appear before
// linked channels, and among each type order follows the ComponentSpec.
const
Vector
<
float
>
input
=
network_unit_
.
concatenated_input
();
EXPECT_EQ
(
input
.
size
(),
network_unit_
.
concatenated_input_dim
());
size_t
index
=
0
;
size_t
end
=
kFixedDim0
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kFixedValue0
);
end
+=
kFixedDim1
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kFixedValue1
);
end
+=
kPreviousDim
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kPreviousValue
);
end
+=
kRecurrentDim
;
for
(;
index
<
end
;
++
index
)
EXPECT_EQ
(
input
[
index
],
kRecurrentValue
);
}
// Tests that NetworkUnitBase refuses to concatenate if there are non-embedded
// fixed embeddings.
TEST_F
(
NetworkUnitBaseTest
,
CannotConcatenateNonEmbeddedFixedFeatures
)
{
const
string
kBadSpec
=
R"(fixed_feature {
embedding_dim: -1
size: 1
})"
;
network_unit_
.
RequestConcatenation
();
network_unit_
.
ProvideConcatenatedInput
();
EXPECT_THAT
(
Run
(
kBadSpec
),
test
::
IsErrorWithSubstr
(
"Non-embedded fixed features cannot be concatenated"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Expects that the two pointers have the same address.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
// A trivial implementation for tests.
class
FooNetwork
:
public
NetworkUnit
{
public:
// Implements NetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
"foo_logits"
;
}
tensorflow
::
Status
Evaluate
(
size_t
step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT
(
FooNetwork
);
// Tests that a human-friendly error is produced for empty network units.
TEST
(
NetworkUnitTest
,
GetClassNameDegenerateName
)
{
ComponentSpec
component_spec
;
EXPECT_DEATH
(
NetworkUnit
::
GetClassName
(
component_spec
),
"No network unit name for component spec"
);
}
// Tests that NetworkUnit::GetClassName() resolves names properly.
TEST
(
NetworkUnitTest
,
GetClassName
)
{
for
(
const
string
&
registered_name
:
{
"FooNetwork"
,
"module.FooNetwork"
,
"some.long.path.to.module.FooNetwork"
})
{
ComponentSpec
component_spec
;
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
registered_name
);
EXPECT_EQ
(
NetworkUnit
::
GetClassName
(
component_spec
),
"FooNetwork"
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/operands.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
OperandHandle
OperandManager
::
Add
(
const
OperandSpec
&
spec
)
{
const
size_t
index
=
specs_
.
size
();
specs_
.
push_back
(
spec
);
switch
(
spec
.
type
)
{
case
OperandType
::
kSingular
:
handle_index_to_typed_index_
.
push_back
(
singular_spans_
.
size
());
singular_spans_
.
emplace_back
(
singular_size_
,
spec
.
size
);
singular_size_
+=
PadToAlignment
(
spec
.
size
);
break
;
case
OperandType
::
kStepwise
:
handle_index_to_typed_index_
.
push_back
(
stepwise_spans_
.
size
());
stepwise_spans_
.
emplace_back
(
stepwise_stride_
,
spec
.
size
);
stepwise_stride_
+=
PadToAlignment
(
spec
.
size
);
break
;
case
OperandType
::
kPairwise
:
handle_index_to_typed_index_
.
push_back
(
pairwise_sizes_
.
size
());
pairwise_sizes_
.
push_back
(
spec
.
size
);
break
;
}
return
OperandHandle
(
index
);
}
void
Operands
::
Reset
(
const
OperandManager
*
manager
,
size_t
pre_allocate_num_steps
)
{
manager_
=
manager
;
handle_index_to_typed_index_
=
manager_
->
handle_index_to_typed_index_
;
stepwise_spans_
=
manager_
->
stepwise_spans_
;
stepwise_stride_
=
manager_
->
stepwise_stride_
;
pairwise_sizes_
=
manager_
->
pairwise_sizes_
;
// Allocate and parcel out singular operands.
singular_operands_
.
clear
();
singular_operands_
.
reserve
(
manager_
->
singular_spans_
.
size
());
singular_array_
.
Reserve
(
manager_
->
singular_size_
);
char
*
data
=
singular_array_
.
view
().
data
();
for
(
const
auto
&
span
:
manager_
->
singular_spans_
)
{
singular_operands_
.
push_back
(
MutableAlignedView
(
data
+
span
.
first
,
span
.
second
));
}
// Pre-allocate and parcel out stepwise operands.
stepwise_operands_
.
clear
();
stepwise_operands_
.
reserve
(
stepwise_spans_
.
size
());
stepwise_array_
.
Reserve
(
stepwise_stride_
*
pre_allocate_num_steps
);
data
=
stepwise_array_
.
view
().
data
();
for
(
const
auto
&
span
:
stepwise_spans_
)
{
stepwise_operands_
.
push_back
(
MutableAlignedArea
(
data
+
span
.
first
,
0
,
span
.
second
,
stepwise_stride_
));
}
// Create empty pairwise operands.
pairwise_operands_
.
clear
();
pairwise_operands_
.
resize
(
pairwise_sizes_
.
size
());
}
void
Operands
::
AddSteps
(
size_t
num_steps
)
{
AddStepwiseSteps
(
num_steps
);
AddPairwiseSteps
(
num_steps
);
}
void
Operands
::
AddStepwiseSteps
(
size_t
num_steps
)
{
if
(
stepwise_operands_
.
empty
())
return
;
// Make room for the new steps.
const
size_t
new_num_views
=
stepwise_operands_
[
0
].
num_views_
+
num_steps
;
const
bool
actually_reallocated
=
stepwise_array_
.
Resize
(
new_num_views
*
stepwise_stride_
);
// Update the base pointers for stepwise operands, if changed.
if
(
actually_reallocated
)
{
char
*
data
=
stepwise_array_
.
view
().
data
();
for
(
size_t
i
=
0
;
i
<
stepwise_operands_
.
size
();
++
i
)
{
stepwise_operands_
[
i
].
data_
=
data
+
stepwise_spans_
[
i
].
first
;
}
}
// Update the number of views in each stepwise operand.
for
(
MutableAlignedArea
&
operand
:
stepwise_operands_
)
{
operand
.
num_views_
=
new_num_views
;
}
}
void
Operands
::
AddPairwiseSteps
(
size_t
num_steps
)
{
if
(
pairwise_operands_
.
empty
())
return
;
const
size_t
new_num_steps
=
pairwise_operands_
[
0
].
num_views_
+
num_steps
;
// Set dimensions for each pairwise operand and accumulate their total stride.
size_t
new_stride
=
0
;
for
(
size_t
i
=
0
;
i
<
pairwise_operands_
.
size
();
++
i
)
{
const
size_t
new_view_size
=
new_num_steps
*
pairwise_sizes_
[
i
];
pairwise_operands_
[
i
].
num_views_
=
new_num_steps
;
pairwise_operands_
[
i
].
view_size_
=
new_view_size
;
new_stride
+=
PadToAlignment
(
new_view_size
);
}
// Note that Reset() does not preserve the existing array and its contents.
// Although preserving existing data would be nice, it is complex because
// pairwise operands grow in both dimensions. In addition, users should be
// allocating pairwise operands in one shot for speed reasons, in which case
// there is no existing data anyways.
pairwise_array_
.
Reset
(
new_num_steps
*
new_stride
);
// Set the new base pointer and stride on each pairwise operand.
char
*
data
=
pairwise_array_
.
view
().
data
();
for
(
MutableAlignedArea
&
operand
:
pairwise_operands_
)
{
operand
.
data_
=
data
;
operand
.
view_stride_
=
new_stride
;
data
+=
PadToAlignment
(
operand
.
view_size_
);
}
DCHECK_EQ
(
data
-
pairwise_array_
.
view
().
data
(),
new_stride
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/operands.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring and allocating operands. An operand is made up of
// aligned byte arrays, and can be used as an input, output, or intermediate
// value in some computation.
#ifndef DRAGNN_RUNTIME_OPERANDS_H_
#define DRAGNN_RUNTIME_OPERANDS_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Possible types of operands.
enum
class
OperandType
{
// A single byte array. For example, an intermediate value that is computed
// once per transition step. Since it is not an output, the same storage
// could be reused across all steps.
kSingular
,
// A sequence of identically-sized byte arrays, one per transition step. For
// example, a layer containing one activation vector per step.
kStepwise
,
// A grid with one byte array for each pair of transition steps, including
// self pairings. The byte arrays are grouped and concatenated in "rows",
// forming one byte array per step. For example, if there are N steps and D
// bytes per pair, the operand would have N arrays of size N*D bytes. In a
// basic attention model with one "similarity" between pairs of steps, one
// might use a pairwise operand with D=sizeof(float). For best performance,
// use Operands::AddSteps() to allocate all steps at once when working with
// pairwise operands.
kPairwise
,
};
// A specification of a operand.
struct
OperandSpec
{
// Creates a trivial specification.
OperandSpec
()
=
default
;
// Creates a specification with the |type| and |size|.
OperandSpec
(
OperandType
type
,
size_t
size
)
:
type
(
type
),
size
(
size
)
{}
// Type of the operand.
OperandType
type
=
OperandType
::
kSingular
;
// Size of each aligned byte array in the operand.
size_t
size
=
0
;
};
// An opaque handle to an operand.
class
OperandHandle
;
// A class that manages a set of operand specifications and associates each
// operand with a handle. Operand contents can be retrieved using these
// handles; see Operands below.
class
OperandManager
{
public:
// Creates an empty manager.
OperandManager
()
=
default
;
// Adds an operand configured according to the |spec| and returns its handle.
OperandHandle
Add
(
const
OperandSpec
&
spec
);
// Accessors.
const
OperandSpec
&
spec
(
OperandHandle
handle
)
const
;
private:
friend
class
Operands
;
// Specification of each operand.
std
::
vector
<
OperandSpec
>
specs_
;
// Mapping from the handle index of an operand to its index amongst operands
// of the same type.
std
::
vector
<
size_t
>
handle_index_to_typed_index_
;
// Span of each singular operand, as a (start-offset,size) pair, relative to
// the byte array containing all singular operands.
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>>
singular_spans_
;
// Span of each stepwise operand, as a (start-offset,size) pair, relative to
// the byte array for each step.
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>>
stepwise_spans_
;
// Size of each pairwise operand.
std
::
vector
<
size_t
>
pairwise_sizes_
;
// Number of bytes used by all singular operands, including alignment padding.
size_t
singular_size_
=
0
;
// Number of bytes used by all stepwise operands on each step, including
// alignment padding.
size_t
stepwise_stride_
=
0
;
};
// A set of operands. The structure of the operands is configured by an
// OperandManager, and operand contents can be accessed using the handles
// produced by the manager.
//
// Multiple Operands instances can share the same OperandManager. In addition,
// an Operands instance can be reused by repeatedly Reset()-ing it, potentially
// with different OperandManagers. Such reuse can reduce allocation overhead.
class
Operands
{
public:
// Creates an empty set.
Operands
()
=
default
;
// Resets this to the operands defined by the |manager|. The |manager| must
// live until this is destroyed or Reset() again, and should not be modified
// during that time. Stepwise and pairwise operands start with 0 steps; use
// AddStep() to extend them. Pre-allocates stepwise operands so that they
// will not be reallocated during the first |pre_allocate_num_steps| calls to
// AddStep(). Invalidates all previously-returned operands.
void
Reset
(
const
OperandManager
*
manager
,
size_t
pre_allocate_num_steps
);
// Extends stepwise and pairwise operands by one or more steps. Requires that
// Reset() was called. Invalidates any previously-returned views of stepwise
// and pairwise operands. Preserves data for pre-existing steps of stepwise
// operands, but not for pre-existing pairwise operands. In general, pairwise
// operands should be allocated in one shot, not incrementally.
void
AddStep
()
{
AddSteps
(
1
);
}
void
AddSteps
(
size_t
num_steps
);
// Returns the singular operand associated with the |handle|. The returned
// view is invalidated by Reset().
MutableAlignedView
GetSingular
(
OperandHandle
handle
)
const
;
// Returns the stepwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea
GetStepwise
(
OperandHandle
handle
)
const
;
// Returns the pairwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea
GetPairwise
(
OperandHandle
handle
)
const
;
private:
// Extends stepwise operands only; see AddSteps().
void
AddStepwiseSteps
(
size_t
num_steps
);
// Extends pairwise operands only; see AddSteps().
void
AddPairwiseSteps
(
size_t
num_steps
);
// Manager of the operands in this set.
const
OperandManager
*
manager_
=
nullptr
;
// Cached members from the |manager_|.
tensorflow
::
gtl
::
ArraySlice
<
size_t
>
handle_index_to_typed_index_
;
tensorflow
::
gtl
::
ArraySlice
<
std
::
pair
<
size_t
,
size_t
>>
stepwise_spans_
;
size_t
stepwise_stride_
=
0
;
tensorflow
::
gtl
::
ArraySlice
<
size_t
>
pairwise_sizes_
;
// Byte arrays holding operands of each type. Storage is separated because
// each type grows differently with the number of steps.
UniqueAlignedArray
singular_array_
;
UniqueAlignedArray
stepwise_array_
;
UniqueAlignedArray
pairwise_array_
;
// Lists of operands of each type.
std
::
vector
<
MutableAlignedView
>
singular_operands_
;
std
::
vector
<
MutableAlignedArea
>
stepwise_operands_
;
std
::
vector
<
MutableAlignedArea
>
pairwise_operands_
;
};
// Implementation details below.
// An opaque handle to an operand.
class
OperandHandle
{
public:
// Creates an invalid handle.
OperandHandle
()
=
default
;
private:
friend
class
OperandManager
;
friend
class
Operands
;
// Creates a handle that points to the |index|.
explicit
OperandHandle
(
size_t
index
)
:
index_
(
index
)
{}
// Index of the operand in its manager.
size_t
index_
=
SIZE_MAX
;
};
inline
const
OperandSpec
&
OperandManager
::
spec
(
OperandHandle
handle
)
const
{
return
specs_
[
handle
.
index_
];
}
inline
MutableAlignedView
Operands
::
GetSingular
(
OperandHandle
handle
)
const
{
DCHECK
(
manager_
->
spec
(
handle
).
type
==
OperandType
::
kSingular
)
<<
"Actual type: "
<<
static_cast
<
int
>
(
manager_
->
spec
(
handle
).
type
);
DCHECK_LE
(
handle
.
index_
,
handle_index_to_typed_index_
.
size
());
return
singular_operands_
[
handle_index_to_typed_index_
[
handle
.
index_
]];
}
inline
MutableAlignedArea
Operands
::
GetStepwise
(
OperandHandle
handle
)
const
{
DCHECK
(
manager_
->
spec
(
handle
).
type
==
OperandType
::
kStepwise
)
<<
"Actual type: "
<<
static_cast
<
int
>
(
manager_
->
spec
(
handle
).
type
);
DCHECK_LE
(
handle
.
index_
,
handle_index_to_typed_index_
.
size
());
return
stepwise_operands_
[
handle_index_to_typed_index_
[
handle
.
index_
]];
}
inline
MutableAlignedArea
Operands
::
GetPairwise
(
OperandHandle
handle
)
const
{
DCHECK
(
manager_
->
spec
(
handle
).
type
==
OperandType
::
kPairwise
)
<<
"Actual type: "
<<
static_cast
<
int
>
(
manager_
->
spec
(
handle
).
type
);
DCHECK_LE
(
handle
.
index_
,
handle_index_to_typed_index_
.
size
());
return
pairwise_operands_
[
handle_index_to_typed_index_
[
handle
.
index_
]];
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_OPERANDS_H_
research/syntaxnet/dragnn/runtime/operands_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
#include <string.h>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Expects that the two pointers are the same.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
// Sets the |vector| to |size| copies of the |value|.
template
<
class
T
>
void
Fill
(
MutableVector
<
T
>
vector
,
size_t
size
,
T
value
)
{
ASSERT_EQ
(
vector
.
size
(),
size
);
for
(
T
&
element
:
vector
)
element
=
value
;
}
// Expects that the |vector| contains |size| copies of the |expected_value|.
template
<
class
T
>
void
ExpectFilled
(
Vector
<
T
>
vector
,
size_t
size
,
T
expected_value
)
{
ASSERT_EQ
(
vector
.
size
(),
size
);
for
(
const
T
element
:
vector
)
EXPECT_EQ
(
element
,
expected_value
);
}
// Tests that OperandManager can add operands and remember their configuration.
TEST
(
OperandManagerTest
,
Add
)
{
OperandManager
manager
;
const
OperandHandle
handle1
=
manager
.
Add
({
OperandType
::
kSingular
,
7
});
const
OperandHandle
handle2
=
manager
.
Add
({
OperandType
::
kStepwise
,
11
});
EXPECT_EQ
(
manager
.
spec
(
handle1
).
type
,
OperandType
::
kSingular
);
EXPECT_EQ
(
manager
.
spec
(
handle1
).
size
,
7
);
EXPECT_EQ
(
manager
.
spec
(
handle2
).
type
,
OperandType
::
kStepwise
);
EXPECT_EQ
(
manager
.
spec
(
handle2
).
size
,
11
);
}
// Tests that Operands contains operands whose dimensions match its manager.
TEST
(
OperandsTest
,
Dimensions
)
{
const
size_t
kDim1
=
3
,
kDim2
=
41
,
kDim3
=
19
,
kDim4
=
77
;
OperandManager
manager
;
const
OperandHandle
handle1
=
manager
.
Add
({
OperandType
::
kSingular
,
kDim1
*
sizeof
(
float
)});
const
OperandHandle
handle2
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim2
*
sizeof
(
double
)});
const
OperandHandle
handle3
=
manager
.
Add
({
OperandType
::
kSingular
,
kDim3
*
sizeof
(
float
)});
const
OperandHandle
handle4
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim4
*
sizeof
(
int
)});
AlignedView
view
;
AlignedArea
area
;
Operands
operands
;
operands
.
Reset
(
&
manager
,
10
);
view
=
operands
.
GetSingular
(
handle1
);
EXPECT_EQ
(
view
.
size
(),
kDim1
*
sizeof
(
float
));
EXPECT_EQ
(
Vector
<
float
>
(
view
).
size
(),
kDim1
);
area
=
operands
.
GetStepwise
(
handle2
);
EXPECT_EQ
(
area
.
num_views
(),
0
);
// no steps yet
EXPECT_EQ
(
area
.
view_size
(),
kDim2
*
sizeof
(
double
));
EXPECT_EQ
(
Matrix
<
double
>
(
area
).
num_rows
(),
0
);
// starts with no steps
EXPECT_EQ
(
Matrix
<
double
>
(
area
).
num_columns
(),
kDim2
);
view
=
operands
.
GetSingular
(
handle3
);
EXPECT_EQ
(
view
.
size
(),
kDim3
*
sizeof
(
float
));
EXPECT_EQ
(
Vector
<
float
>
(
view
).
size
(),
kDim3
);
area
=
operands
.
GetStepwise
(
handle4
);
EXPECT_EQ
(
area
.
num_views
(),
0
);
// no steps yet
EXPECT_EQ
(
area
.
view_size
(),
kDim4
*
sizeof
(
int
));
EXPECT_EQ
(
Matrix
<
int
>
(
area
).
num_rows
(),
0
);
// starts with no steps
EXPECT_EQ
(
Matrix
<
int
>
(
area
).
num_columns
(),
kDim4
);
}
// Tests that Operands can incrementally extend stepwise operands while
// preserving existing values.
TEST
(
OperandsTest
,
AddStepToStepwise
)
{
const
size_t
kDim1
=
23
,
kDim2
=
29
;
OperandManager
manager
;
const
OperandHandle
handle1
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim1
*
sizeof
(
double
)});
const
OperandHandle
handle2
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim2
*
sizeof
(
int
)});
Operands
operands
;
operands
.
Reset
(
&
manager
,
10
);
// Repeatedly add a step and fill it with values.
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
operands
.
AddStep
();
Fill
(
MutableVector
<
double
>
(
operands
.
GetStepwise
(
handle1
).
view
(
i
)),
kDim1
,
1000.0
+
i
);
Fill
(
MutableVector
<
int
>
(
operands
.
GetStepwise
(
handle2
).
view
(
i
)),
kDim2
,
2000
+
i
);
}
// Check that data from earlier steps is preserved across reallocations.
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
ExpectFilled
(
Vector
<
double
>
(
operands
.
GetStepwise
(
handle1
).
view
(
i
)),
kDim1
,
1000.0
+
i
);
ExpectFilled
(
Vector
<
int
>
(
operands
.
GetStepwise
(
handle2
).
view
(
i
)),
kDim2
,
2000
+
i
);
}
}
// Tests that Operands can add multiple steps at once.
TEST
(
OperandsTest
,
AddStepsToStepwise
)
{
const
size_t
kDim1
=
23
,
kDim2
=
29
;
OperandManager
manager
;
const
OperandHandle
handle1
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim1
*
sizeof
(
double
)});
const
OperandHandle
handle2
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim2
*
sizeof
(
int
)});
Operands
operands
;
operands
.
Reset
(
&
manager
,
10
);
// Repeatedly add blocks of steps and fill them with values.
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
if
(
i
%
10
==
0
)
operands
.
AddSteps
(
10
);
// occasionally add a block
Fill
(
MutableVector
<
double
>
(
operands
.
GetStepwise
(
handle1
).
view
(
i
)),
kDim1
,
1000.0
+
i
);
Fill
(
MutableVector
<
int
>
(
operands
.
GetStepwise
(
handle2
).
view
(
i
)),
kDim2
,
2000
+
i
);
}
// Check that data from earlier steps is preserved across reallocations.
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
ExpectFilled
(
Vector
<
double
>
(
operands
.
GetStepwise
(
handle1
).
view
(
i
)),
kDim1
,
1000.0
+
i
);
ExpectFilled
(
Vector
<
int
>
(
operands
.
GetStepwise
(
handle2
).
view
(
i
)),
kDim2
,
2000
+
i
);
}
}
// Tests that Operands can add multiple steps to a pairwise operand.
TEST
(
OperandsTest
,
AddStepsPairwise
)
{
const
size_t
kDim1
=
4
,
kDim2
=
31
;
OperandManager
manager
;
const
OperandHandle
handle1
=
manager
.
Add
({
OperandType
::
kPairwise
,
kDim1
});
const
OperandHandle
handle2
=
manager
.
Add
({
OperandType
::
kPairwise
,
kDim2
});
Operands
operands
;
operands
.
Reset
(
&
manager
,
10
);
{
// A 1x1 pairwise operand.
operands
.
AddSteps
(
1
);
const
MutableAlignedArea
area1
=
operands
.
GetPairwise
(
handle1
);
const
MutableAlignedArea
area2
=
operands
.
GetPairwise
(
handle2
);
EXPECT_EQ
(
area1
.
num_views
(),
1
);
EXPECT_EQ
(
area2
.
num_views
(),
1
);
EXPECT_EQ
(
area1
.
view_size
(),
kDim1
);
EXPECT_EQ
(
area2
.
view_size
(),
kDim2
);
// Write to operands to test the validity of the underlying memory region.
memset
(
area1
.
view
(
0
).
data
(),
0
,
kDim1
);
memset
(
area2
.
view
(
0
).
data
(),
0
,
kDim2
);
}
{
// A 10x10 pairwise operand.
operands
.
AddSteps
(
9
);
const
MutableAlignedArea
area1
=
operands
.
GetPairwise
(
handle1
);
const
MutableAlignedArea
area2
=
operands
.
GetPairwise
(
handle2
);
EXPECT_EQ
(
area1
.
num_views
(),
10
);
EXPECT_EQ
(
area2
.
num_views
(),
10
);
EXPECT_EQ
(
area1
.
view_size
(),
10
*
kDim1
);
EXPECT_EQ
(
area2
.
view_size
(),
10
*
kDim2
);
// Infer the stride by comparing pointers between consecutive views.
const
size_t
expected_stride
=
PadToAlignment
(
10
*
kDim1
)
+
PadToAlignment
(
10
*
kDim2
);
EXPECT_EQ
(
area1
.
view
(
1
).
data
()
-
area1
.
view
(
0
).
data
(),
expected_stride
);
EXPECT_EQ
(
area2
.
view
(
1
).
data
()
-
area2
.
view
(
0
).
data
(),
expected_stride
);
// Write to operands to test the validity of the underlying memory region.
memset
(
area1
.
view
(
9
).
data
(),
0
,
10
*
kDim1
);
memset
(
area2
.
view
(
9
).
data
(),
0
,
10
*
kDim2
);
}
}
// Tests that Operands can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST
(
OperandsTest
,
ResetWithDifferentManagers
)
{
std
::
vector
<
OperandManager
>
managers
;
std
::
vector
<
std
::
tuple
<
OperandHandle
,
OperandHandle
,
OperandHandle
>>
handles
;
for
(
int
dim
=
0
;
dim
<
10
;
++
dim
)
{
managers
.
emplace_back
();
handles
.
emplace_back
(
managers
.
back
().
Add
({
OperandType
::
kSingular
,
dim
*
sizeof
(
double
)}),
managers
.
back
().
Add
({
OperandType
::
kStepwise
,
dim
*
sizeof
(
int
)}),
managers
.
back
().
Add
({
OperandType
::
kPairwise
,
dim
*
sizeof
(
float
)}));
}
Operands
operands
;
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
for
(
int
dim
=
0
;
dim
<
10
;
++
dim
)
{
operands
.
Reset
(
&
managers
[
dim
],
10
);
const
OperandHandle
singular_handle
=
std
::
get
<
0
>
(
handles
[
dim
]);
const
OperandHandle
stepwise_handle
=
std
::
get
<
1
>
(
handles
[
dim
]);
const
OperandHandle
pairwise_handle
=
std
::
get
<
2
>
(
handles
[
dim
]);
// Fill the singular operand.
Fill
(
MutableVector
<
double
>
(
operands
.
GetSingular
(
singular_handle
)),
dim
,
100.0
*
trial
+
dim
);
// Check the singular operands.
ExpectFilled
(
Vector
<
double
>
(
operands
.
GetSingular
(
singular_handle
)),
dim
,
100.0
*
trial
+
dim
);
// Repeatedly add a step and fill it with values.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
operands
.
AddStep
();
Fill
(
MutableVector
<
int
>
(
operands
.
GetStepwise
(
stepwise_handle
).
view
(
step
)),
dim
,
1000
*
trial
+
100
*
dim
+
step
);
}
// Check that data from earlier steps is preserved across reallocations.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
ExpectFilled
(
Vector
<
int
>
(
operands
.
GetStepwise
(
stepwise_handle
).
view
(
step
)),
dim
,
1000
*
trial
+
100
*
dim
+
step
);
}
// Check the dimensions of pairwise operands.
Matrix
<
float
>
pairwise
(
operands
.
GetPairwise
(
pairwise_handle
));
EXPECT_EQ
(
pairwise
.
num_rows
(),
100
);
EXPECT_EQ
(
pairwise
.
num_columns
(),
100
*
dim
);
}
}
}
// Tests that one OperandManager can be shared simultaneously between multiple
// Operands instances.
TEST
(
OperandsTest
,
SharedManager
)
{
const
size_t
kDim
=
17
;
OperandManager
manager
;
const
OperandHandle
singular_handle
=
manager
.
Add
({
OperandType
::
kSingular
,
kDim
*
sizeof
(
double
)});
const
OperandHandle
stepwise_handle
=
manager
.
Add
({
OperandType
::
kStepwise
,
kDim
*
sizeof
(
int
)});
std
::
vector
<
Operands
>
operands_vec
(
10
);
for
(
Operands
&
operands
:
operands_vec
)
operands
.
Reset
(
&
manager
,
10
);
// Fill all singular operands.
for
(
int
trial
=
0
;
trial
<
operands_vec
.
size
();
++
trial
)
{
const
Operands
&
operands
=
operands_vec
[
trial
];
Fill
(
MutableVector
<
double
>
(
operands
.
GetSingular
(
singular_handle
)),
kDim
,
3.0
*
trial
);
}
// Check all singular operands.
for
(
int
trial
=
0
;
trial
<
operands_vec
.
size
();
++
trial
)
{
const
Operands
&
operands
=
operands_vec
[
trial
];
ExpectFilled
(
Vector
<
double
>
(
operands
.
GetSingular
(
singular_handle
)),
kDim
,
3.0
*
trial
);
}
// Fill all stepwise operands. Interleave operations on the operands on each
// step, so all operands are "active" at the same time.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
Operands
&
operands
=
operands_vec
[
trial
];
operands
.
AddStep
();
Fill
(
MutableVector
<
int
>
(
operands
.
GetStepwise
(
stepwise_handle
).
view
(
step
)),
kDim
,
trial
*
999
+
step
);
}
}
// Check all stepwise operands.
for
(
int
step
=
0
;
step
<
100
;
++
step
)
{
for
(
int
trial
=
0
;
trial
<
10
;
++
trial
)
{
const
Operands
&
operands
=
operands_vec
[
trial
];
ExpectFilled
(
Vector
<
int
>
(
operands
.
GetStepwise
(
stepwise_handle
).
view
(
step
)),
kDim
,
trial
*
999
+
step
);
}
}
}
// Tests that an Operands uses all of the pre-allocated steps and reallocates
// exactly when it exhausts the pre-allocated array.
TEST
(
OperandsTest
,
UsesPreAllocatedSteps
)
{
const
size_t
kBytes
=
5
;
const
size_t
kPreAllocateNumSteps
=
10
;
OperandManager
manager
;
const
OperandHandle
handle
=
manager
.
Add
({
OperandType
::
kStepwise
,
kBytes
});
Operands
operands
;
operands
.
Reset
(
&
manager
,
kPreAllocateNumSteps
);
// The first N steps fit exactly in the pre-allocated array. Access the base
// of the stepwise array via the first view.
operands
.
AddStep
();
char
*
const
pre_allocated_data
=
operands
.
GetStepwise
(
handle
).
view
(
0
).
data
();
for
(
size_t
step
=
1
;
step
<
kPreAllocateNumSteps
;
++
step
)
{
operands
.
AddStep
();
ASSERT_EQ
(
operands
.
GetStepwise
(
handle
).
view
(
0
).
data
(),
pre_allocated_data
);
}
// The N+1'st step triggers a reallocation, which is guaranteed to yield a new
// pointer because it creates a separate array and copies into it.
operands
.
AddStep
();
ASSERT_NE
(
operands
.
GetStepwise
(
handle
).
view
(
0
).
data
(),
pre_allocated_data
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Links to the previous step in the same component. Templated on a bool that
// indicates the direction that the transition system runs in.
template
<
bool
left_to_right
>
class
RecurrentSequenceLinker
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
override
;
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
override
;
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
override
;
};
template
<
bool
left_to_right
>
bool
RecurrentSequenceLinker
<
left_to_right
>::
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
{
TransitionSystemTraits
traits
(
component_spec
);
// Here, fml="bias" and source_translator="history" are a DRAGNN recipe for
// linking to the previous transition step. More concretely,
// * "bias" always extracts index 0.
// * "history" subtracts the index it is given from (#steps - 1).
// Putting the two together, we link to (#steps - 1 - 0); i.e., the previous
// transition step.
return
(
channel
.
fml
()
==
"bias"
||
channel
.
fml
()
==
"bias(0)"
)
&&
channel
.
source_component
()
==
component_spec
.
name
()
&&
channel
.
source_translator
()
==
"history"
&&
traits
.
is_left_to_right
==
left_to_right
&&
traits
.
is_sequential
;
}
template
<
bool
left_to_right
>
tensorflow
::
Status
RecurrentSequenceLinker
<
left_to_right
>::
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
{
return
tensorflow
::
Status
::
OK
();
}
template
<
bool
left_to_right
>
tensorflow
::
Status
RecurrentSequenceLinker
<
left_to_right
>::
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
{
links
->
resize
(
source_num_steps
);
if
(
left_to_right
)
{
int32
index
=
-
1
;
for
(
int32
&
link
:
*
links
)
link
=
index
++
;
}
else
{
int32
index
=
static_cast
<
int32
>
(
source_num_steps
)
-
1
;
for
(
int32
&
link
:
*
links
)
link
=
--
index
;
}
return
tensorflow
::
Status
::
OK
();
}
using
LeftToRightRecurrentSequenceLinker
=
RecurrentSequenceLinker
<
true
>
;
using
RightToLeftRecurrentSequenceLinker
=
RecurrentSequenceLinker
<
false
>
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LeftToRightRecurrentSequenceLinker
);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
RightToLeftRecurrentSequenceLinker
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns a ComponentSpec that the linker will support.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"test_component"
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"shift-only"
);
LinkedFeatureChannel
*
channel
=
component_spec
.
add_linked_feature
();
channel
->
set_fml
(
"bias"
);
channel
->
set_source_component
(
"test_component"
);
channel
->
set_source_translator
(
"history"
);
return
component_spec
;
}
// Tests that the linker supports appropriate specs.
TEST
(
RecurrentSequenceLinkerTest
,
Supported
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"LeftToRightRecurrentSequenceLinker"
);
(
*
component_spec
.
mutable_transition_system
()
->
mutable_parameters
())[
"left_to_right"
]
=
"false"
;
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"RightToLeftRecurrentSequenceLinker"
);
channel
.
set_fml
(
"bias(0)"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"RightToLeftRecurrentSequenceLinker"
);
(
*
component_spec
.
mutable_transition_system
()
->
mutable_parameters
())[
"left_to_right"
]
=
"true"
;
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"LeftToRightRecurrentSequenceLinker"
);
}
// Tests that the linker requires the right transition system.
TEST
(
RecurrentSequenceLinkerTest
,
WrongTransitionSystem
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right FML.
TEST
(
RecurrentSequenceLinkerTest
,
WrongFml
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_fml
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires a recurrent link.
TEST
(
RecurrentSequenceLinkerTest
,
WrongSource
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_source_component
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right translator.
TEST
(
RecurrentSequenceLinkerTest
,
WrongTranslator
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_source_translator
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker can be initialized and used to extract links.
TEST
(
RecurrentSequenceLinkerTest
,
InitializeAndGetLinks
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
std
::
unique_ptr
<
SequenceLinker
>
linker
;
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"LeftToRightRecurrentSequenceLinker"
,
channel
,
component_spec
,
&
linker
));
InputBatchCache
input
;
std
::
vector
<
int32
>
links
=
{
123
,
456
,
789
};
// gets overwritten
TF_ASSERT_OK
(
linker
->
GetLinks
(
10
,
&
input
,
&
links
));
const
std
::
vector
<
int32
>
expected_links
=
{
-
1
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
EXPECT_EQ
(
links
,
expected_links
);
}
// Tests that the links are reversed for right-to-left components.
TEST
(
RecurrentSequenceLinkerTest
,
InitializeAndGetLinksRightToLeft
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
std
::
unique_ptr
<
SequenceLinker
>
linker
;
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"RightToLeftRecurrentSequenceLinker"
,
channel
,
component_spec
,
&
linker
));
InputBatchCache
input
;
std
::
vector
<
int32
>
links
=
{
123
,
456
,
789
};
// gets overwritten
TF_ASSERT_OK
(
linker
->
GetLinks
(
10
,
&
input
,
&
links
));
const
std
::
vector
<
int32
>
expected_links
=
{
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
,
-
1
};
EXPECT_EQ
(
links
,
expected_links
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/reversed_sequence_linker.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Applies a reversed identity function.
class
ReversedSequenceLinker
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
override
;
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
override
;
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
override
;
};
bool
ReversedSequenceLinker
::
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
{
TransitionSystemTraits
traits
(
component_spec
);
// Note: Add more "||" clauses as needed.
return
((
channel
.
fml
()
==
"input.focus"
&&
channel
.
source_translator
()
==
"reverse-token"
)
||
(
channel
.
fml
()
==
"char-input.focus"
&&
channel
.
source_translator
()
==
"reverse-char"
))
&&
traits
.
is_sequential
;
}
tensorflow
::
Status
ReversedSequenceLinker
::
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
ReversedSequenceLinker
::
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
{
links
->
resize
(
source_num_steps
);
int32
index
=
links
->
size
();
for
(
int32
&
link
:
*
links
)
link
=
--
index
;
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
ReversedSequenceLinker
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/reversed_sequence_linker_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns a ComponentSpec that the linker will support.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"shift-only"
);
LinkedFeatureChannel
*
channel
=
component_spec
.
add_linked_feature
();
channel
->
set_fml
(
"input.focus"
);
channel
->
set_source_translator
(
"reverse-token"
);
return
component_spec
;
}
// Tests that the linker supports appropriate specs.
TEST
(
ReversedSequenceLinkerTest
,
Supported
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"ReversedSequenceLinker"
);
channel
.
set_fml
(
"char-input.focus"
);
channel
.
set_source_translator
(
"reverse-char"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"ReversedSequenceLinker"
);
}
// Tests that the linker requires the right transition system.
TEST
(
IdentitySequenceLinkerTest
,
WrongTransitionSystem
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right FML.
TEST
(
ReversedSequenceLinkerTest
,
WrongFml
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_fml
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right translator.
TEST
(
ReversedSequenceLinkerTest
,
WrongTranslator
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_source_translator
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right combination of FML and translator.
TEST
(
ReversedSequenceLinkerTest
,
MismatchedFmlAndTranslator
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_fml
(
"input.focus"
);
channel
.
set_source_translator
(
"reverse-char"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
channel
.
set_fml
(
"char-input.focus"
);
channel
.
set_source_translator
(
"reverse-token"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker can be initialized and used to extract links.
TEST
(
ReversedSequenceLinkerTest
,
InitializeAndGetLinks
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
std
::
unique_ptr
<
SequenceLinker
>
linker
;
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"ReversedSequenceLinker"
,
channel
,
component_spec
,
&
linker
));
InputBatchCache
input
;
std
::
vector
<
int32
>
links
=
{
123
,
456
,
789
};
// gets overwritten
TF_ASSERT_OK
(
linker
->
GetLinks
(
10
,
&
input
,
&
links
));
const
std
::
vector
<
int32
>
expected_links
=
{
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
};
EXPECT_EQ
(
links
,
expected_links
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/select_best_component_transformer.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Transformer that selects the best component subclass for the ComponentSpec.
class
SelectBestComponentTransformer
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
override
{
string
best_component_type
;
TF_RETURN_IF_ERROR
(
Component
::
Select
(
*
component_spec
,
&
best_component_type
));
component_spec
->
mutable_component_builder
()
->
set_registered_name
(
best_component_type
);
if
(
component_type
!=
best_component_type
)
{
LOG
(
INFO
)
<<
"Component '"
<<
component_spec
->
name
()
<<
"' builder updated from "
<<
component_type
<<
" to "
<<
best_component_type
<<
"."
;
}
else
{
VLOG
(
2
)
<<
"Component '"
<<
component_spec
->
name
()
<<
"' builder type "
<<
component_type
<<
" unchanged."
;
}
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
SelectBestComponentTransformer
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/select_best_component_transformer_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/extensions.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Base class for test components.
class
TestComponentBase
:
public
Component
{
public:
// Partially implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
,
VariableStore
*
,
NetworkStateManager
*
,
ExtensionManager
*
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
,
ComputeSession
*
,
ComponentTrace
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
PreferredTo
(
const
Component
&
)
const
override
{
return
false
;
}
};
// Supports components whose builder name includes "Foo".
class
ContainsFoo
:
public
TestComponentBase
{
public:
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
.
find
(
"Foo"
)
!=
string
::
npos
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ContainsFoo
);
// Supports components whose builder name includes "Bar".
class
ContainsBar
:
public
TestComponentBase
{
public:
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
.
find
(
"Bar"
)
!=
string
::
npos
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ContainsBar
);
// Tests that a spec with an unknown builder name causes an error.
TEST
(
SelectBestComponentTransformerTest
,
Unknown
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"unknown"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Could not find a best"
));
}
// Tests that a spec with builder "Foo" is changed to "ContainsFoo".
TEST
(
SelectBestComponentTransformerTest
,
ChangeToContainsFoo
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"Foo"
);
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ContainsFoo"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that a spec with builder "Bar" is changed to "ContainsBar".
TEST
(
SelectBestComponentTransformerTest
,
ChangeToContainsBar
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"Bar"
);
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ContainsBar"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that a spec with builder "FooBar" causes a conflict.
TEST
(
SelectBestComponentTransformerTest
,
Conflict
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"FooBar"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"both think they should be dis-preferred"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_backend.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/core/component_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
std
::
function
<
int
(
int
,
int
,
int
)
>
SequenceBackend
::
GetStepLookupFunction
(
const
string
&
method
)
{
if
(
method
==
"reverse-char"
||
method
==
"reverse-token"
)
{
// Reverses the |index| in the sequence. We are agnostic to whether the
// input is a sequence of tokens or chars.
return
[
this
](
int
unused_batch_index
,
int
unused_beam_index
,
int
index
)
{
index
=
sequence_size_
-
index
-
1
;
return
index
>=
0
&&
index
<
sequence_size_
?
index
:
-
1
;
};
}
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Unknown step lookup function: "
<<
method
;
}
void
SequenceBackend
::
InitializeComponent
(
const
ComponentSpec
&
spec
)
{
name_
=
spec
.
name
();
}
void
SequenceBackend
::
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
{
// Store the |parent_states| for forwarding to downstream components.
parent_states_
=
parent_states
;
}
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
SequenceBackend
::
GetBeam
()
{
// Forward the states of the previous component.
return
parent_states_
;
}
int
SequenceBackend
::
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
{
// Forward the |current_index| to the previous component.
return
current_index
;
}
int
SequenceBackend
::
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
{
// Always return 0 since there is only one beam.
return
0
;
}
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
SequenceBackend
::
GetTraceProtos
()
const
{
// Return a single trace, since the beam and batch sizes are fixed at 1.
return
{{
ComponentTrace
()}};
}
string
SequenceBackend
::
Name
()
const
{
return
name_
;
}
int
SequenceBackend
::
BeamSize
()
const
{
return
1
;
}
int
SequenceBackend
::
BatchSize
()
const
{
return
1
;
}
bool
SequenceBackend
::
IsReady
()
const
{
return
true
;
}
bool
SequenceBackend
::
IsTerminal
()
const
{
return
true
;
}
void
SequenceBackend
::
FinalizeData
()
{}
void
SequenceBackend
::
ResetComponent
()
{}
void
SequenceBackend
::
InitializeTracing
()
{}
void
SequenceBackend
::
DisableTracing
()
{}
int
SequenceBackend
::
StepsTaken
(
int
batch_index
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
bool
SequenceBackend
::
AdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
num_items
,
int
num_actions
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
AdvanceFromOracle
()
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
SequenceBackend
::
GetOracleLabels
()
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int
*
offset_array_output
,
int
offset_array_size
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
BulkDenseFeatureSize
()
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
std
::
vector
<
LinkFeatures
>
SequenceBackend
::
GetRawLinkFeatures
(
int
channel_id
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
REGISTER_DRAGNN_COMPONENT
(
SequenceBackend
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_backend.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#define DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#include <functional>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Runtime-only component backend for sequence-based models. This is not used
// at training time, and provides trivial implementations of most methods. This
// is intended to be used with alternative feature extraction approaches, such
// as SequenceExtractor.
class
SequenceBackend
:
public
dragnn
::
Component
{
public:
// Sets the size of the sequence in the current input.
void
SetSequenceSize
(
int
size
)
{
sequence_size_
=
size
;
}
// Implements dragnn::Component.
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
const
string
&
method
)
override
;
void
InitializeComponent
(
const
ComponentSpec
&
spec
)
override
;
void
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
override
;
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
GetBeam
()
override
;
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
;
int
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
override
;
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
GetTraceProtos
()
const
override
;
string
Name
()
const
override
;
int
BeamSize
()
const
override
;
int
BatchSize
()
const
override
;
bool
IsReady
()
const
override
;
bool
IsTerminal
()
const
override
;
void
FinalizeData
()
override
;
void
ResetComponent
()
override
;
void
InitializeTracing
()
override
;
void
DisableTracing
()
override
;
// Not implemented, crashes when called.
int
StepsTaken
(
int
batch_index
)
const
override
;
// Not implemented, crashes when called.
bool
AdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
num_items
,
int
num_actions
)
override
;
// Not implemented, crashes when called.
void
AdvanceFromOracle
()
override
;
// Not implemented, crashes when called.
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
GetOracleLabels
()
const
override
;
// Not implemented, crashes when called.
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
override
;
// Not implemented, crashes when called.
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
;
// Not implemented, crashes when called.
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
;
// Not implemented, crashes when called.
void
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int
*
offset_array_output
,
int
offset_array_size
)
override
;
// Not implemented, crashes when called.
int
BulkDenseFeatureSize
()
const
override
;
// Not implemented, crashes when called.
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
;
// Not implemented, crashes when called.
void
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
override
;
private:
// Name of the component that this backend supports.
string
name_
;
// Size of the current input sequence.
int
sequence_size_
=
0
;
// Parent states passed to InitializeData(), and passed along in GetBeam().
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return -1 if the sequence size was never set.
TEST
(
SequenceBackendTest
,
ReverseCharUninitialized
)
{
for
(
const
string
&
reverse_method
:
{
"reverse-char"
,
"reverse-token"
})
{
SequenceBackend
backend
;
const
std
::
function
<
int
(
int
,
int
,
int
)
>
reverse
=
backend
.
GetStepLookupFunction
(
reverse_method
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
-
1
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
-
1
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
}
}
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return the reverse of the step index w.r.t. the most recent call
// to SetSequenceSize().
TEST
(
SequenceBackendTest
,
ReverseCharAfterSetSequenceSize
)
{
for
(
const
string
&
reverse_method
:
{
"reverse-char"
,
"reverse-token"
})
{
SequenceBackend
backend
;
const
std
::
function
<
int
(
int
,
int
,
int
)
>
reverse
=
backend
.
GetStepLookupFunction
(
reverse_method
);
backend
.
SetSequenceSize
(
10
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
9
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
8
);
EXPECT_EQ
(
reverse
(
8
,
8
,
8
),
1
);
EXPECT_EQ
(
reverse
(
9
,
9
,
9
),
0
);
EXPECT_EQ
(
reverse
(
10
,
10
,
10
),
-
1
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
5
),
4
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
backend
.
SetSequenceSize
(
11
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
10
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
9
);
EXPECT_EQ
(
reverse
(
8
,
8
,
8
),
2
);
EXPECT_EQ
(
reverse
(
9
,
9
,
9
),
1
);
EXPECT_EQ
(
reverse
(
10
,
10
,
10
),
0
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
5
),
5
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
}
}
// Tests that the input beam is forwarded.
TEST
(
SequenceBackendTest
,
BeamForwarding
)
{
SequenceBackend
backend
;
const
TransitionState
*
parent_state
=
nullptr
;
parent_state
+=
1234
;
// arbitrary non-null pointer
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states
=
{
{
parent_state
}};
const
int
ignored_max_beam_size
=
999
;
InputBatchCache
*
ignored_input
=
nullptr
;
backend
.
InitializeData
(
parent_states
,
ignored_max_beam_size
,
ignored_input
);
EXPECT_EQ
(
backend
.
GetBeam
(),
parent_states
);
}
// Tests the accessors of the backend.
TEST
(
SequenceBackendTest
,
Accessors
)
{
SequenceBackend
backend
;
ComponentSpec
spec
;
spec
.
set_name
(
"foo"
);
backend
.
InitializeComponent
(
spec
);
EXPECT_EQ
(
backend
.
Name
(),
"foo"
);
EXPECT_EQ
(
backend
.
BeamSize
(),
1
);
EXPECT_EQ
(
backend
.
BatchSize
(),
1
);
EXPECT_TRUE
(
backend
.
IsReady
());
EXPECT_TRUE
(
backend
.
IsTerminal
());
}
// Tests the trivial mutators of the backend.
TEST
(
SequenceBackendTest
,
Mutators
)
{
SequenceBackend
backend
;
// These are NOPs and should not crash.
backend
.
FinalizeData
();
backend
.
ResetComponent
();
backend
.
InitializeTracing
();
backend
.
DisableTracing
();
}
// Tests the beam index accessors of the backend.
TEST
(
SequenceBackendTest
,
BeamIndex
)
{
SequenceBackend
backend
;
// This always returns the current_index (first arg).
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
0
,
0
),
0
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
1
,
2
),
1
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
-
1
,
-
1
),
-
1
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
10
,
99
),
10
);
// This always returns 0.
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
0
,
0
,
0
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
1
,
2
,
3
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
-
1
,
-
1
,
-
1
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
123
,
456
,
789
),
0
);
}
// Tests the that the backend produces a single empty trace.
TEST
(
SequenceBackendTest
,
Tracing
)
{
SequenceBackend
backend
;
const
ComponentTrace
empty_trace
;
const
auto
actual_traces
=
backend
.
GetTraceProtos
();
ASSERT_EQ
(
actual_traces
.
size
(),
1
);
ASSERT_EQ
(
actual_traces
[
0
].
size
(),
1
);
EXPECT_THAT
(
actual_traces
[
0
][
0
],
test
::
EqualsProto
(
empty_trace
));
}
// Tests the unsupported methods of the backend.
TEST
(
SequenceBackendTest
,
UnsupportedMethods
)
{
SequenceBackend
backend
;
EXPECT_DEATH
(
backend
.
StepsTaken
(
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AdvanceFromPrediction
(
nullptr
,
0
,
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AdvanceFromOracle
(),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetOracleLabels
(),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetFixedFeatures
(
nullptr
,
nullptr
,
nullptr
,
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
BulkGetFixedFeatures
(
BulkFeatureExtractor
(
nullptr
,
nullptr
,
nullptr
)),
"Not supported"
);
EXPECT_DEATH
(
backend
.
BulkEmbedFixedFeatures
(
0
,
0
,
0
,
{},
nullptr
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetRawLinkFeatures
(
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AddTranslatedLinkFeaturesToTrace
({},
0
),
"Not supported"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string.h>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Sequence-based bulk version of DynamicComponent.
class
SequenceBulkDynamicComponent
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
private:
// Evaluates all input features in the |state|, concatenates them into a
// matrix of inputs in the |network_states|, and returns the matrix.
Matrix
<
float
>
EvaluateInputs
(
const
SequenceModel
::
EvaluateState
&
state
,
const
NetworkStates
&
network_states
)
const
;
// Managers for input embeddings.
FixedEmbeddingManager
fixed_embedding_manager_
;
LinkedEmbeddingManager
linked_embedding_manager_
;
// Sequence-based model evaluator.
SequenceModel
sequence_model_
;
// Network unit for bulk inference.
std
::
unique_ptr
<
BulkNetworkUnit
>
bulk_network_unit_
;
// Concatenated input matrix.
LocalMatrixHandle
<
float
>
inputs_handle_
;
// Intermediate values used by sequence models.
SharedExtensionHandle
<
SequenceModel
::
EvaluateState
>
evaluate_state_handle_
;
};
bool
SequenceBulkDynamicComponent
::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
// Require embedded fixed features.
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
embedding_dim
()
<
0
)
return
false
;
}
// Require non-transformed and non-recurrent linked features.
// TODO(googleuser): Make SequenceLinks support transformed linked features?
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
embedding_dim
()
>=
0
)
return
false
;
if
(
channel
.
source_component
()
==
component_spec
.
name
())
return
false
;
}
return
normalized_builder_name
==
"SequenceBulkDynamicComponent"
&&
SequenceModel
::
Supports
(
component_spec
);
}
// Returns the sum of the dimensions of all channels in the |manager|.
template
<
class
EmbeddingManager
>
size_t
SumEmbeddingDimensions
(
const
EmbeddingManager
&
manager
)
{
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
manager
.
num_channels
();
++
i
)
{
sum
+=
manager
.
embedding_dim
(
i
);
}
return
sum
;
}
tensorflow
::
Status
SequenceBulkDynamicComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
TF_RETURN_IF_ERROR
(
BulkNetworkUnit
::
CreateOrError
(
BulkNetworkUnit
::
GetClassName
(
component_spec
),
&
bulk_network_unit_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
const
size_t
concatenated_input_dim
=
SumEmbeddingDimensions
(
fixed_embedding_manager_
)
+
SumEmbeddingDimensions
(
linked_embedding_manager_
);
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
ValidateInputDimension
(
concatenated_input_dim
));
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLocal
(
concatenated_input_dim
,
&
inputs_handle_
));
TF_RETURN_IF_ERROR
(
sequence_model_
.
Initialize
(
component_spec
,
bulk_network_unit_
->
GetLogitsName
(),
&
fixed_embedding_manager_
,
&
linked_embedding_manager_
,
network_state_manager
));
extension_manager
->
GetShared
(
&
evaluate_state_handle_
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceBulkDynamicComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
const
NetworkStates
&
network_states
=
session_state
->
network_states
;
SequenceModel
::
EvaluateState
&
state
=
session_state
->
extensions
.
Get
(
evaluate_state_handle_
);
TF_RETURN_IF_ERROR
(
sequence_model_
.
Preprocess
(
session_state
,
compute_session
,
&
state
));
const
Matrix
<
float
>
inputs
=
EvaluateInputs
(
state
,
network_states
);
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Evaluate
(
inputs
,
session_state
));
return
sequence_model_
.
Predict
(
network_states
,
&
state
);
}
Matrix
<
float
>
SequenceBulkDynamicComponent
::
EvaluateInputs
(
const
SequenceModel
::
EvaluateState
&
state
,
const
NetworkStates
&
network_states
)
const
{
const
MutableMatrix
<
float
>
inputs
=
network_states
.
GetLocal
(
inputs_handle_
);
// Declared here for reuse in the loop below.
bool
is_out_of_bounds
=
false
;
Vector
<
float
>
embedding
;
// Handle forward and reverse iteration via a start index and increment.
int
target_index
=
sequence_model_
.
left_to_right
()
?
0
:
state
.
num_steps
-
1
;
const
int
target_increment
=
sequence_model_
.
left_to_right
()
?
1
:
-
1
;
for
(
size_t
step_index
=
0
;
step_index
<
state
.
num_steps
;
++
step_index
,
target_index
+=
target_increment
)
{
const
MutableVector
<
float
>
row
=
inputs
.
row
(
step_index
);
float
*
output
=
row
.
data
();
for
(
size_t
channel_id
=
0
;
channel_id
<
state
.
features
.
num_channels
();
++
channel_id
)
{
embedding
=
state
.
features
.
GetEmbedding
(
channel_id
,
target_index
);
memcpy
(
output
,
embedding
.
data
(),
embedding
.
size
()
*
sizeof
(
float
));
output
+=
embedding
.
size
();
}
for
(
size_t
channel_id
=
0
;
channel_id
<
state
.
links
.
num_channels
();
++
channel_id
)
{
state
.
links
.
Get
(
channel_id
,
target_index
,
&
embedding
,
&
is_out_of_bounds
);
memcpy
(
output
,
embedding
.
data
(),
embedding
.
size
()
*
sizeof
(
float
));
output
+=
embedding
.
size
();
}
DCHECK_EQ
(
output
,
row
.
end
());
}
return
Matrix
<
float
>
(
inputs
);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SequenceBulkDynamicComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
size_t
kNumSteps
=
50
;
constexpr
size_t
kFixedDim
=
11
;
constexpr
size_t
kFixedVocabularySize
=
123
;
constexpr
float
kFixedValue
=
0.5
;
constexpr
size_t
kLinkedDim
=
13
;
constexpr
float
kLinkedValue
=
1.25
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
char
kLogitsName
[]
=
"logits"
;
constexpr
size_t
kLogitsDim
=
kFixedDim
+
kLinkedDim
;
// Adds one to all inputs.
class
BulkAddOne
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
network_state_manager
->
AddLayer
(
kLogitsName
,
kLogitsDim
,
&
logits_handle_
);
}
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
kLogitsName
;
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
{
const
MutableMatrix
<
float
>
logits
=
session_state
->
network_states
.
GetLayer
(
logits_handle_
);
for
(
size_t
row
=
0
;
row
<
inputs
.
num_rows
();
++
row
)
{
for
(
size_t
column
=
0
;
column
<
inputs
.
num_columns
();
++
column
)
{
logits
.
row
(
row
)[
column
]
=
inputs
.
row
(
row
)[
column
]
+
1.0
;
}
}
return
tensorflow
::
Status
::
OK
();
}
private:
// Output logits.
LayerHandle
<
float
>
logits_handle_
;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkAddOne
);
// A component that also prefers other but is triggered on the presence of a
// resource. This can be used to cause a component selection conflict.
class
ImTheWorst
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
component_spec
.
resource_size
()
>
0
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheWorst
);
// Extractor that produces a sequence of zeros.
class
ExtractZeros
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
assign
(
kNumSteps
,
0
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
ExtractZeros
);
// Linker that produces a sequence of zeros.
class
LinkZeros
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
assign
(
kNumSteps
,
0
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkZeros
);
// Predictor that captures the logits.
class
CaptureLogits
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
)
const
override
{
logits_
=
logits
;
return
tensorflow
::
Status
::
OK
();
}
// Returns the captured logits.
static
Matrix
<
float
>
GetCapturedLogits
()
{
return
logits_
;
}
private:
// Logits from the most recent call to Predict().
static
Matrix
<
float
>
logits_
;
};
Matrix
<
float
>
CaptureLogits
::
logits_
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
CaptureLogits
);
class
SequenceBulkDynamicComponentTest
:
public
NetworkTestBase
{
protected:
SequenceBulkDynamicComponentTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
}
// Returns a spec that the network supports.
ComponentSpec
GetSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
kTestComponentName
);
component_spec
.
set_num_actions
(
kLogitsDim
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"AddOne"
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SequenceBulkDynamicComponent"
);
auto
&
component_parameters
=
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
();
component_parameters
[
"sequence_extractors"
]
=
"ExtractZeros"
;
component_parameters
[
"sequence_linkers"
]
=
"LinkZeros"
;
component_parameters
[
"sequence_predictor"
]
=
"CaptureLogits"
;
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
kFixedDim
);
fixed_feature
->
set_vocabulary_size
(
kFixedVocabularySize
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
return
component_spec
;
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
)
{
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kFixedVocabularySize
,
kFixedDim
,
kFixedValue
);
std
::
unique_ptr
<
Component
>
component
;
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"SequenceBulkDynamicComponent"
,
&
component
));
TF_RETURN_IF_ERROR
(
component
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
// Allocates network states for a few steps.
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kLinkedValue
);
StartComponent
(
0
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
return
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
);
}
// Input batch injected into Evaluate() by default.
InputBatchCache
input_
;
// Backend injected into Evaluate().
SequenceBackend
backend_
;
};
// Tests that the supported spec is supported.
TEST_F
(
SequenceBulkDynamicComponentTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
GetSupportedSpec
();
string
component_type
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_type
));
EXPECT_EQ
(
component_type
,
"SequenceBulkDynamicComponent"
);
TF_ASSERT_OK
(
Run
(
component_spec
));
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetCapturedLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kFixedDim
+
kLinkedDim
);
for
(
size_t
row
=
0
;
row
<
kNumSteps
;
++
row
)
{
size_t
column
=
0
;
for
(;
column
<
kFixedDim
;
++
column
)
{
EXPECT_EQ
(
logits
.
row
(
row
)[
column
],
kFixedValue
+
1.0
);
}
for
(;
column
<
kFixedDim
+
kLinkedDim
;
++
column
)
{
EXPECT_EQ
(
logits
.
row
(
row
)[
column
],
kLinkedValue
+
1.0
);
}
}
}
// Tests that links cannot be recurrent.
TEST_F
(
SequenceBulkDynamicComponentTest
,
ForbidRecurrences
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
kTestComponentName
);
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that the component prefers others.
TEST_F
(
SequenceBulkDynamicComponentTest
,
PrefersOthers
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
add_resource
();
// Adding a resource triggers the ImTheWorst component, which also prefers
// itself and leads to a selection conflict.
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"both think they should be dis-preferred"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment