ModelZoo / ResNet50_tensorflow, commit edea2b67

Authored May 11, 2018 by Terry Koo
Parent: a4bb31d0

    Remove runtime because reasons.
Changes: 291 files. Showing 20 changed files with 0 additions and 3456 deletions (+0 -3456).
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/master-spec  +0 -160
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/rnn.flow  +0 -0
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/tagger.flow  +0 -0
research/syntaxnet/dragnn/runtime/network_states.cc  +0 -197
research/syntaxnet/dragnn/runtime/network_states.h  +0 -422
research/syntaxnet/dragnn/runtime/network_states_test.cc  +0 -508
research/syntaxnet/dragnn/runtime/network_unit.cc  +0 -43
research/syntaxnet/dragnn/runtime/network_unit.h  +0 -95
research/syntaxnet/dragnn/runtime/network_unit_base.cc  +0 -171
research/syntaxnet/dragnn/runtime/network_unit_base.h  +0 -137
research/syntaxnet/dragnn/runtime/network_unit_base_test.cc  +0 -403
research/syntaxnet/dragnn/runtime/network_unit_test.cc  +0 -82
research/syntaxnet/dragnn/runtime/operands.cc  +0 -142
research/syntaxnet/dragnn/runtime/operands.h  +0 -236
research/syntaxnet/dragnn/runtime/operands_test.cc  +0 -350
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers.cc  +0 -96
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers_test.cc  +0 -151
research/syntaxnet/dragnn/runtime/reversed_sequence_linker.cc  +0 -76
research/syntaxnet/dragnn/runtime/reversed_sequence_linker_test.cc  +0 -129
research/syntaxnet/dragnn/runtime/select_best_component_transformer.cc  +0 -58
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/master-spec
deleted 100644 → 0
component {
  name: "rnn"
  transition_system {
    registered_name: "shift-only"
    parameters {
      key: "left_to_right"
      value: "false"
    }
    parameters {
      key: "parser_skip_deterministic"
      value: "false"
    }
  }
  resource {
    name: "words-embedding-input"
    part {
      file_format: "tf-records"
      record_format: "syntaxnet.TokenEmbedding"
    }
  }
  resource {
    name: "words-vocab-input"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "char-ngram-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "word-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "label-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "myelin-flow"
    part {
      file_format: "model"
      record_format: "sling.myelin.Flow"
    }
  }
  fixed_feature {
    name: "char_ngrams"
    fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
    embedding_dim: -1
    vocabulary_size: 25788
    size: 3
  }
  fixed_feature {
    name: "words"
    fml: "input.token.word(min-freq=2)"
    embedding_dim: -1
    vocabulary_size: 23769
    size: 1
  }
  network_unit {
    registered_name: "LSTMNetwork"
    parameters {
      key: "hidden_layer_sizes"
      value: "128"
    }
    parameters {
      key: "omit_logits"
      value: "true"
    }
  }
  backend {
    registered_name: "SyntaxNetComponent"
  }
  num_actions: 1
  attention_component: ""
  component_builder {
    registered_name: "MyelinDynamicComponent"
  }
}
component {
  name: "tagger"
  transition_system {
    registered_name: "tagger"
    parameters {
      key: "parser_skip_deterministic"
      value: "false"
    }
  }
  resource {
    name: "tag-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "tag-to-category"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "label-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "myelin-flow"
    part {
      file_format: "model"
      record_format: "sling.myelin.Flow"
    }
  }
  linked_feature {
    name: "recurrence"
    fml: "bias(0)"
    embedding_dim: -1
    size: 1
    source_component: "tagger"
    source_translator: "history"
    source_layer: "layer_0"
  }
  linked_feature {
    name: "rnn"
    fml: "input.focus"
    embedding_dim: -1
    size: 1
    source_component: "rnn"
    source_translator: "reverse-token"
    source_layer: "layer_0"
  }
  network_unit {
    registered_name: "FeedForwardNetwork"
    parameters {
      key: "hidden_layer_sizes"
      value: "64,64"
    }
  }
  backend {
    registered_name: "SyntaxNetComponent"
  }
  num_actions: 45
  attention_component: ""
  component_builder {
    registered_name: "MyelinDynamicComponent"
  }
}
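
For context: the file above is a text-format MasterSpec proto, and each top-level component { ... } block is one ComponentSpec ("rnn" and "tagger" here). Below is a minimal sketch of parsing such a spec, assuming the generated proto header "dragnn/protos/spec.pb.h" and the standard protobuf TextFormat API; the sketch is illustrative and is not part of this commit.

#include <fstream>
#include <iostream>
#include <sstream>

#include "dragnn/protos/spec.pb.h"  // assumed location of MasterSpec/ComponentSpec
#include "google/protobuf/text_format.h"

int main() {
  // Read the textproto shown above.
  std::ifstream in("master-spec");
  std::stringstream buffer;
  buffer << in.rdbuf();

  // Parse it into a MasterSpec; each component { ... } block becomes one
  // ComponentSpec in the repeated |component| field.
  syntaxnet::dragnn::MasterSpec spec;
  if (!google::protobuf::TextFormat::ParseFromString(buffer.str(), &spec)) {
    std::cerr << "Malformed master-spec\n";
    return 1;
  }
  for (const auto &component : spec.component()) {
    std::cout << component.name() << ": "
              << component.network_unit().registered_name() << "\n";
  }
  return 0;
}

On this file the loop would print "rnn: LSTMNetwork" and "tagger: FeedForwardNetwork".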
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/rnn.flow
deleted 100644 → 0
File deleted.
research/syntaxnet/dragnn/runtime/myelin/testdata/myelination_output/tagger.flow
deleted 100644 → 0
File deleted.
research/syntaxnet/dragnn/runtime/network_states.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/runtime/network_states.h"

#include "tensorflow/core/lib/core/errors.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Returns the first value in |container| whose ".name" field is |name|, or
// null if not found.
template <class Container>
const typename Container::value_type *Find(const Container &container,
                                           const string &name) {
  for (auto &value : container) {
    if (value.name == name) return &value;
  }
  return nullptr;
}

}  // namespace

tensorflow::Status NetworkStateManager::AddComponent(const string &name) {
  if (Find(components_, name) != nullptr) {
    return tensorflow::errors::FailedPrecondition("Component '", name,
                                                  "' already exists");
  }

  // Success; make modifications.
  components_.emplace_back(name);
  return tensorflow::Status::OK();
}

tensorflow::Status NetworkStateManager::AddLayerImpl(
    const string &name, std::type_index type, bool is_pairwise, size_t bytes,
    size_t *component_index, OperandHandle *operand_handle) {
  if (components_.empty()) {
    return tensorflow::errors::FailedPrecondition("No current component");
  }
  ComponentConfig &component = components_.back();

  if (Find(component.layers, name) != nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "Layer '", name, "' already exists in component '", component.name,
        "'");
  }

  if (component.aliases.find(name) != component.aliases.end()) {
    return tensorflow::errors::FailedPrecondition(
        "Layer '", name, "' conflicts with an existing alias in component '",
        component.name, "'");
  }

  // Success; make modifications.
  const OperandType operand_type =
      is_pairwise ? OperandType::kPairwise : OperandType::kStepwise;
  *component_index = components_.size() - 1;
  *operand_handle = component.manager.Add({operand_type, bytes});
  component.layers.emplace_back(name, type, *operand_handle);
  return tensorflow::Status::OK();
}

tensorflow::Status NetworkStateManager::AddLayerAlias(const string &alias,
                                                      const string &name) {
  if (components_.empty()) {
    return tensorflow::errors::FailedPrecondition("No current component");
  }
  ComponentConfig &component = components_.back();

  if (Find(component.layers, name) == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "Target layer '", name, "' of alias '", alias,
        "' does not exist in component '", component.name, "'");
  }

  if (Find(component.layers, alias) != nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "Alias '", alias, "' conflicts with an existing layer in component '",
        component.name, "'");
  }

  if (component.aliases.find(alias) != component.aliases.end()) {
    return tensorflow::errors::FailedPrecondition(
        "Alias '", alias, "' already exists in component '", component.name,
        "'");
  }

  // Success; make modifications.
  component.aliases[alias] = name;
  return tensorflow::Status::OK();
}

tensorflow::Status NetworkStateManager::AddLocalImpl(const OperandSpec &spec,
                                                     OperandHandle *handle) {
  if (components_.empty()) {
    return tensorflow::errors::FailedPrecondition("No current component");
  }
  ComponentConfig &component = components_.back();

  // Success; make modifications.
  *handle = component.manager.Add(spec);
  return tensorflow::Status::OK();
}

tensorflow::Status NetworkStateManager::LookupLayerImpl(
    const string &component_name, const string &layer_name_or_alias,
    std::type_index type, bool is_pairwise, size_t *bytes,
    size_t *component_index, OperandHandle *operand_handle) const {
  const ComponentConfig *component = Find(components_, component_name);
  if (component == nullptr) {
    return tensorflow::errors::FailedPrecondition("Unknown component '",
                                                  component_name, "'");
  }

  // If necessary, resolve a layer alias into a layer name.  Note that aliases
  // are non-transitive, since AddLayerAlias() requires that the target of the
  // alias is a layer.
  const auto it = component->aliases.find(layer_name_or_alias);
  const string &layer_name =
      it != component->aliases.end() ? it->second : layer_name_or_alias;

  const LayerConfig *layer = Find(component->layers, layer_name);
  if (layer == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "Unknown layer '", layer_name, "' in component '", component_name,
        "'");
  }

  if (layer->type != type) {
    return tensorflow::errors::InvalidArgument(
        "Layer '", layer_name, "' in component '", component_name,
        "' does not match its expected type");
  }

  const OperandType required_type =
      is_pairwise ? OperandType::kPairwise : OperandType::kStepwise;
  const OperandSpec &operand_spec = component->manager.spec(layer->handle);
  if (operand_spec.type != required_type) {
    return tensorflow::errors::InvalidArgument(
        "Layer '", layer_name, "' in component '", component_name,
        "' does not match its expected OperandType");
  }

  // Success; make modifications.
  *bytes = operand_spec.size;
  *component_index = component - components_.data();
  *operand_handle = layer->handle;
  return tensorflow::Status::OK();
}

void NetworkStates::Reset(const NetworkStateManager *manager) {
  manager_ = manager;
  num_active_components_ = 0;

  // Never shrink the |component_operands_|, to avoid deallocating (and then
  // eventually reallocating) operand arrays.
  if (manager_->components_.size() > component_operands_.size()) {
    component_operands_.resize(manager_->components_.size());
  }
}

tensorflow::Status NetworkStates::StartNextComponent(
    size_t pre_allocate_num_steps) {
  if (manager_ == nullptr) {
    return tensorflow::errors::FailedPrecondition("No manager");
  }
  if (num_active_components_ >= manager_->components_.size()) {
    return tensorflow::errors::OutOfRange("No next component");
  }

  // Success; make modifications.
  const OperandManager *operand_manager =
      &manager_->components_[num_active_components_].manager;
  component_operands_[num_active_components_].Reset(operand_manager,
                                                    pre_allocate_num_steps);
  ++num_active_components_;
  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_states.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// Utils for declaring, allocating, and retrieving network states, similar to
// the "NetworkState" class and the "network_states" argument to the build_*()
// methods of ComponentBuilderBase; see component.py.
//
// In brief, a DRAGNN network consists of a sequence of named components, each
// of which produces a set of named output layers.  Each component can access
// its own layers as well as those of preceding components.  Components can
// also access "local operands", which are like layers but private to that
// particular component.  Local operands can be useful for, e.g., caching an
// intermediate result in a complex computation.
//
// For example, suppose a network has two components: "tagger" and "parser",
// where the parser uses the hidden activations of the tagger.  In this case,
// the tagger can add a layer called "hidden" at init time and fill that layer
// at processing time.  Correspondingly, the parser can look for a layer called
// "hidden" in the "tagger" component at init time, and read the activations at
// processing time.  (Note that for convenience, such links should be handled
// using the utils in linked_embeddings.h).
//
// As another example, suppose we are implementing an LSTM and we wish to keep
// the cell state private.  In this case, the LSTM component could add a layer
// for exporting the hidden activations and a local matrix for the sequence of
// cell states.  A more compact approach is to use two local vectors instead,
// one for even steps and the other for odd steps.

#ifndef DRAGNN_RUNTIME_NETWORK_STATES_H_
#define DRAGNN_RUNTIME_NETWORK_STATES_H_

#include <stddef.h>
#include <stdint.h>

#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/operands.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Opaque handles used to access typed layers or local operands.
template <class T> class LayerHandle;
template <class T> class PairwiseLayerHandle;
template <class T> class LocalVectorHandle;
template <class T> class LocalMatrixHandle;

// A class that manages the state of a DRAGNN network and associates each layer
// and local operand with a handle.  Layer and local operand contents can be
// retrieved using these handles; see NetworkStates below.
class NetworkStateManager {
 public:
  // Creates an empty manager.
  NetworkStateManager() = default;

  // Adds a component named |name| and makes it the current component.  The
  // |name| must be unique in the network.  Components are sequenced in the
  // order they are added.  On error, returns non-OK and modifies nothing.
  tensorflow::Status AddComponent(const string &name);

  // Adds a layer named |name| to the current component and sets |handle| to
  // its handle.  The |name| must be unique in the current component.  The
  // layer is realized as a Matrix<T> with one row per step and |dimension|
  // columns.  On error, returns non-OK and modifies nothing.
  template <class T>
  tensorflow::Status AddLayer(const string &name, size_t dimension,
                              LayerHandle<T> *handle);

  // As above, but for pairwise layers.
  template <class T>
  tensorflow::Status AddLayer(const string &name, size_t dimension,
                              PairwiseLayerHandle<T> *handle);

  // As above, but for a local Vector<T> or Matrix<T> operand.  The operand is
  // "local" in the sense that only the caller knows its handle.
  template <class T>
  tensorflow::Status AddLocal(size_t dimension, LocalVectorHandle<T> *handle);
  template <class T>
  tensorflow::Status AddLocal(size_t dimension, LocalMatrixHandle<T> *handle);

  // Makes |alias| an alias of the layer named |name| in the current component,
  // so that lookups of |alias| resolve to |name|.  The |name| must already
  // exist as a layer, and layer names and aliases must be unique within each
  // component.  On error, returns non-OK and modifies nothing.
  tensorflow::Status AddLayerAlias(const string &alias, const string &name);

  // Finds the layer that matches |layer_name_or_alias| in the component named
  // |component_name|.  Sets |dimension| to its dimension and |handle| to its
  // handle.  On error, returns non-OK and modifies nothing.
  template <class T>
  tensorflow::Status LookupLayer(const string &component_name,
                                 const string &layer_name_or_alias,
                                 size_t *dimension,
                                 LayerHandle<T> *handle) const;

  // As above, but for pairwise layers.
  template <class T>
  tensorflow::Status LookupLayer(const string &component_name,
                                 const string &layer_name_or_alias,
                                 size_t *dimension,
                                 PairwiseLayerHandle<T> *handle) const;

 private:
  friend class NetworkStates;

  // Configuration information for a layer.
  struct LayerConfig {
    // Creates a config for a layer with the |name|, |type| ID, and |handle|.
    LayerConfig(const string &name, std::type_index type, OperandHandle handle)
        : name(name), type(type), handle(handle) {}

    // Name of the layer.
    string name;

    // Type ID of the layer contents.
    std::type_index type;

    // Handle of the operand that holds the layer contents.
    OperandHandle handle;
  };

  // Configuration information for a component.
  struct ComponentConfig {
    // Creates an empty config for a component with the |name|.
    explicit ComponentConfig(const string &name) : name(name) {}

    // Name of the component.
    string name;

    // Manager for the operands used by the component.
    OperandManager manager;

    // Configuration of each layer produced by the component.
    std::vector<LayerConfig> layers;

    // Mapping from layer alias to layer name in the component.
    std::map<string, string> aliases;
  };

  // Implements the non-templated part of AddLayer().  Adds a layer with the
  // |name|, |type| ID, and size in |bytes|.  Sets the |component_index| and
  // |operand_handle| according to the containing component and operand.  If
  // |is_pairwise| is true, then the new layer is pairwise (vs stepwise).  On
  // error, returns non-OK and modifies nothing.
  tensorflow::Status AddLayerImpl(const string &name, std::type_index type,
                                  bool is_pairwise, size_t bytes,
                                  size_t *component_index,
                                  OperandHandle *operand_handle);

  // Implements the non-templated portion of AddLocal*().  Adds a local operand
  // with the |spec| and sets |handle| to its handle.  On error, returns non-OK
  // and modifies nothing.
  tensorflow::Status AddLocalImpl(const OperandSpec &spec,
                                  OperandHandle *handle);

  // Implements the non-templated portion of LookupLayer().  Finds the layer
  // that matches the |component_name| and |layer_name_or_alias|.  That layer
  // must match the |type| ID.  Sets |bytes| to its size, |component_index| to
  // the index of its containing component, and |operand_handle| to the handle
  // of its underlying operand.  If |is_pairwise| is true, then the layer must
  // be pairwise (vs stepwise).  On error, returns non-OK and modifies nothing.
  tensorflow::Status LookupLayerImpl(const string &component_name,
                                     const string &layer_name_or_alias,
                                     std::type_index type, bool is_pairwise,
                                     size_t *bytes, size_t *component_index,
                                     OperandHandle *operand_handle) const;

  // Ordered list of configurations for the components in the network.
  std::vector<ComponentConfig> components_;
};

// A set of network states.  The structure of the network is configured by a
// NetworkStateManager, and layer and local operand contents can be accessed
// using the handles produced by the manager.
//
// Multiple NetworkStates instances can share the same NetworkStateManager.  In
// addition, a NetworkStates instance can be reused by repeatedly Reset()-ing
// it, potentially with different NetworkStateManagers.  Such reuse can reduce
// allocation overhead.
class NetworkStates {
 public:
  // Creates an uninitialized set of states.
  NetworkStates() = default;

  // Resets this to an empty set configured by the |manager|.  The |manager|
  // must live until this is destroyed or Reset(), and should not be modified
  // during that time.  No current component is set; call StartNextComponent()
  // to start the first component.
  void Reset(const NetworkStateManager *manager);

  // Starts the next component and makes it the current component.  Initially,
  // the component has zero steps but more can be added using AddStep().  Uses
  // |pre_allocate_num_steps| to pre-allocate storage; see Operands::Reset().
  // On error, returns non-OK and modifies nothing.
  tensorflow::Status StartNextComponent(size_t pre_allocate_num_steps);

  // Adds one or more steps to the current component.  Invalidates all
  // previously-returned matrices of the current component.
  void AddStep() { AddSteps(1); }
  void AddSteps(size_t num_steps);

  // Returns the layer associated with the |handle|.
  template <class T>
  MutableMatrix<T> GetLayer(LayerHandle<T> handle) const;

  // Returns the pairwise layer associated with the |handle|.
  template <class T>
  MutableMatrix<T> GetLayer(PairwiseLayerHandle<T> handle) const;

  // Returns the local vector or matrix associated with the |handle| in the
  // current component.
  template <class T>
  MutableVector<T> GetLocal(LocalVectorHandle<T> handle) const;
  template <class T>
  MutableMatrix<T> GetLocal(LocalMatrixHandle<T> handle) const;

 private:
  // Manager of this set of network states.
  const NetworkStateManager *manager_ = nullptr;

  // Number of active components in the |component_operands_|.
  size_t num_active_components_ = 0;

  // Ordered list of per-component operands.  Only the first
  // |num_active_components_| entries are valid.
  std::vector<Operands> component_operands_;
};

// Implementation details below.

// An opaque handle to a typed layer of some component.
template <class T>
class LayerHandle {
 public:
  static_assert(IsAlignable<T>(), "T must be alignable");

  // Creates an invalid handle.
  LayerHandle() = default;

 private:
  friend class NetworkStateManager;
  friend class NetworkStates;

  // Index of the containing component in the network state manager.
  size_t component_index_ = SIZE_MAX;

  // Handle of the operand holding the layer.
  OperandHandle operand_handle_;
};

// An opaque handle to a typed pairwise layer of some component.
template <class T>
class PairwiseLayerHandle {
 public:
  static_assert(IsAlignable<T>(), "T must be alignable");

  // Creates an invalid handle.
  PairwiseLayerHandle() = default;

 private:
  friend class NetworkStateManager;
  friend class NetworkStates;

  // Index of the containing component in the network state manager.
  size_t component_index_ = SIZE_MAX;

  // Handle of the operand holding the layer.
  OperandHandle operand_handle_;
};

// An opaque handle to a typed local operand of some component.
template <class T>
class LocalVectorHandle {
 public:
  static_assert(IsAlignable<T>(), "T must be alignable");

  // Creates an invalid handle.
  LocalVectorHandle() = default;

 private:
  friend class NetworkStateManager;
  friend class NetworkStates;

  // Handle of the local operand.
  OperandHandle operand_handle_;
};

// An opaque handle to a typed local operand of some component.
template <class T>
class LocalMatrixHandle {
 public:
  static_assert(IsAlignable<T>(), "T must be alignable");

  // Creates an invalid handle.
  LocalMatrixHandle() = default;

 private:
  friend class NetworkStateManager;
  friend class NetworkStates;

  // Handle of the local operand.
  OperandHandle operand_handle_;
};

template <class T>
tensorflow::Status NetworkStateManager::AddLayer(const string &name,
                                                 size_t dimension,
                                                 LayerHandle<T> *handle) {
  return AddLayerImpl(name, std::type_index(typeid(T)), /*is_pairwise=*/false,
                      dimension * sizeof(T), &handle->component_index_,
                      &handle->operand_handle_);
}

template <class T>
tensorflow::Status NetworkStateManager::AddLayer(
    const string &name, size_t dimension, PairwiseLayerHandle<T> *handle) {
  return AddLayerImpl(name, std::type_index(typeid(T)), /*is_pairwise=*/true,
                      dimension * sizeof(T), &handle->component_index_,
                      &handle->operand_handle_);
}

template <class T>
tensorflow::Status NetworkStateManager::AddLocal(size_t dimension,
                                                 LocalVectorHandle<T> *handle) {
  return AddLocalImpl({OperandType::kSingular, dimension * sizeof(T)},
                      &handle->operand_handle_);
}

template <class T>
tensorflow::Status NetworkStateManager::AddLocal(size_t dimension,
                                                 LocalMatrixHandle<T> *handle) {
  return AddLocalImpl({OperandType::kStepwise, dimension * sizeof(T)},
                      &handle->operand_handle_);
}

template <class T>
tensorflow::Status NetworkStateManager::LookupLayer(
    const string &component_name, const string &layer_name_or_alias,
    size_t *dimension, LayerHandle<T> *handle) const {
  TF_RETURN_IF_ERROR(LookupLayerImpl(
      component_name, layer_name_or_alias, std::type_index(typeid(T)),
      /*is_pairwise=*/false, dimension, &handle->component_index_,
      &handle->operand_handle_));
  DCHECK_EQ(*dimension % sizeof(T), 0);
  *dimension /= sizeof(T);  // bytes => Ts
  return tensorflow::Status::OK();
}

template <class T>
tensorflow::Status NetworkStateManager::LookupLayer(
    const string &component_name, const string &layer_name_or_alias,
    size_t *dimension, PairwiseLayerHandle<T> *handle) const {
  TF_RETURN_IF_ERROR(LookupLayerImpl(
      component_name, layer_name_or_alias, std::type_index(typeid(T)),
      /*is_pairwise=*/true, dimension, &handle->component_index_,
      &handle->operand_handle_));
  DCHECK_EQ(*dimension % sizeof(T), 0);
  *dimension /= sizeof(T);  // bytes => Ts
  return tensorflow::Status::OK();
}

inline void NetworkStates::AddSteps(size_t num_steps) {
  component_operands_[num_active_components_ - 1].AddSteps(num_steps);
}

template <class T>
MutableMatrix<T> NetworkStates::GetLayer(LayerHandle<T> handle) const {
  return MutableMatrix<T>(
      component_operands_[handle.component_index_].GetStepwise(
          handle.operand_handle_));
}

template <class T>
MutableMatrix<T> NetworkStates::GetLayer(PairwiseLayerHandle<T> handle) const {
  return MutableMatrix<T>(
      component_operands_[handle.component_index_].GetPairwise(
          handle.operand_handle_));
}

template <class T>
MutableVector<T> NetworkStates::GetLocal(LocalVectorHandle<T> handle) const {
  return MutableVector<T>(
      component_operands_[num_active_components_ - 1].GetSingular(
          handle.operand_handle_));
}

template <class T>
MutableMatrix<T> NetworkStates::GetLocal(LocalMatrixHandle<T> handle) const {
  return MutableMatrix<T>(
      component_operands_[num_active_components_ - 1].GetStepwise(
          handle.operand_handle_));
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_NETWORK_STATES_H_
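
To make the deleted API concrete, here is a minimal usage sketch of the manager/states pair, built only from the declarations above; the component name, layer name, and dimensions are invented for illustration, and the sketch is not part of this commit.

#include "dragnn/runtime/network_states.h"

using syntaxnet::dragnn::runtime::LayerHandle;
using syntaxnet::dragnn::runtime::NetworkStateManager;
using syntaxnet::dragnn::runtime::NetworkStates;

tensorflow::Status RunTaggerExample() {
  // Init time: declare one component with one 128-dim float layer.
  NetworkStateManager manager;
  LayerHandle<float> hidden_handle;
  TF_RETURN_IF_ERROR(manager.AddComponent("tagger"));
  TF_RETURN_IF_ERROR(
      manager.AddLayer("hidden", /*dimension=*/128, &hidden_handle));

  // Processing time: bind a NetworkStates to the manager, start the
  // component, and add one step per transition taken.
  NetworkStates network_states;
  network_states.Reset(&manager);
  TF_RETURN_IF_ERROR(
      network_states.StartNextComponent(/*pre_allocate_num_steps=*/16));
  network_states.AddStep();

  // The layer is a MutableMatrix<float> with one row per step and 128
  // columns; fill the row for the step just added.
  for (float &value : network_states.GetLayer(hidden_handle).row(0)) {
    value = 0.0f;
  }
  return tensorflow::Status::OK();
}

A downstream component would recover the same handle at init time via manager.LookupLayer("tagger", "hidden", &dimension, &handle), which is how cross-component links are wired up.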
research/syntaxnet/dragnn/runtime/network_states_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/runtime/network_states.h"

#include <stddef.h>
#include <string.h>

#include <string>
#include <vector>

#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Expects that two objects have identical bit representations.
template <class T>
void ExpectBitwiseEqual(const T &object1, const T &object2) {
  EXPECT_EQ(memcmp(&object1, &object2, sizeof(T)), 0);
}

// Expects that the |matrix| has the given dimensions.
template <class T>
void ExpectDimensions(MutableMatrix<T> matrix, size_t num_rows,
                      size_t num_columns) {
  EXPECT_EQ(matrix.num_rows(), num_rows);
  EXPECT_EQ(matrix.num_columns(), num_columns);
}

// Sets the |vector| to |size| copies of the |value|.
template <class T>
void Fill(MutableVector<T> vector, size_t size, T value) {
  ASSERT_EQ(vector.size(), size);
  for (T &element : vector) element = value;
}

// Expects that the |vector| contains |size| copies of the |expected_value|.
template <class T>
void ExpectFilled(MutableVector<T> vector, size_t size, T expected_value) {
  ASSERT_EQ(vector.size(), size);
  for (const T element : vector) EXPECT_EQ(element, expected_value);
}

// Tests that NetworkStateManager can add a named component.
TEST(NetworkStateManagerTest, AddComponent) {
  NetworkStateManager manager;

  TF_EXPECT_OK(manager.AddComponent("foo/bar"));
  EXPECT_THAT(manager.AddComponent("foo/bar"),
              test::IsErrorWithSubstr("Component 'foo/bar' already exists"));

  // Empty component name is weird, but OK.
  TF_EXPECT_OK(manager.AddComponent(""));
  EXPECT_THAT(manager.AddComponent(""),
              test::IsErrorWithSubstr("Component '' already exists"));
}

// Tests that NetworkStateManager can add a named layer to the current
// component.
TEST(NetworkStateManagerTest, AddLayer) {
  NetworkStateManager manager;
  LayerHandle<float> unused_layer_handle;

  EXPECT_THAT(manager.AddLayer("layer", 1, &unused_layer_handle),
              test::IsErrorWithSubstr("No current component"));

  TF_EXPECT_OK(manager.AddComponent("component"));
  TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
  EXPECT_THAT(manager.AddLayer("layer", 2, &unused_layer_handle),
              test::IsErrorWithSubstr(
                  "Layer 'layer' already exists in component 'component'"));
}

// Tests that NetworkStateManager can add a named pairwise layer to the current
// component.
TEST(NetworkStateManagerTest, AddLayerPairwise) {
  NetworkStateManager manager;
  PairwiseLayerHandle<float> unused_layer_handle;

  EXPECT_THAT(manager.AddLayer("layer", 1, &unused_layer_handle),
              test::IsErrorWithSubstr("No current component"));

  TF_EXPECT_OK(manager.AddComponent("component"));
  TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
  EXPECT_THAT(manager.AddLayer("layer", 2, &unused_layer_handle),
              test::IsErrorWithSubstr(
                  "Layer 'layer' already exists in component 'component'"));
}

// Tests that NetworkStateManager can add an alias to an existing layer.  Also
// tests that layer and alias names are required to be unique.
TEST(NetworkStateManagerTest, AddLayerAlias) {
  NetworkStateManager manager;
  LayerHandle<float> unused_layer_handle;

  EXPECT_THAT(manager.AddLayerAlias("alias", "layer"),
              test::IsErrorWithSubstr("No current component"));

  TF_EXPECT_OK(manager.AddComponent("component"));
  EXPECT_THAT(manager.AddLayerAlias("alias", "layer"),
              test::IsErrorWithSubstr(
                  "Target layer 'layer' of alias 'alias' does not "
                  "exist in component 'component'"));

  TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
  TF_EXPECT_OK(manager.AddLayerAlias("alias", "layer"));
  EXPECT_THAT(manager.AddLayerAlias("alias", "layer"),
              test::IsErrorWithSubstr(
                  "Alias 'alias' already exists in component 'component'"));
  EXPECT_THAT(manager.AddLayer("alias", 2, &unused_layer_handle),
              test::IsErrorWithSubstr(
                  "Layer 'alias' conflicts with an existing alias "
                  "in component 'component'"));

  TF_EXPECT_OK(manager.AddLayer("layer2", 2, &unused_layer_handle));
  EXPECT_THAT(manager.AddLayerAlias("layer2", "layer"),
              test::IsErrorWithSubstr(
                  "Alias 'layer2' conflicts with an existing layer "
                  "in component 'component'"));
}

// Tests that NetworkStateManager can add a local matrix or vector to the
// current component.
TEST(NetworkStateManagerTest, AddLocal) {
  NetworkStateManager manager;
  LocalVectorHandle<float> unused_local_vector_handle;
  LocalMatrixHandle<float> unused_local_matrix_handle;

  EXPECT_THAT(manager.AddLocal(11, &unused_local_matrix_handle),
              test::IsErrorWithSubstr("No current component"));

  TF_EXPECT_OK(manager.AddComponent("component"));
  TF_EXPECT_OK(manager.AddLocal(22, &unused_local_matrix_handle));
  TF_EXPECT_OK(manager.AddLocal(33, &unused_local_vector_handle));
}

// Tests that NetworkStateManager can look up existing layers or aliases, and
// fails on invalid layer or component names and for mismatched types.
TEST(NetworkStateManagerTest, LookupLayer) {
  NetworkStateManager manager;
  LayerHandle<char> char_handle;
  LayerHandle<int16> int16_handle;
  LayerHandle<uint16> uint16_handle;
  PairwiseLayerHandle<char> pairwise_char_handle;
  size_t dimension = 0;

  // Add some typed layers and aliases.
  TF_ASSERT_OK(manager.AddComponent("foo"));
  TF_ASSERT_OK(manager.AddLayer("char", 5, &char_handle));
  TF_ASSERT_OK(manager.AddLayer("int16", 7, &int16_handle));
  TF_ASSERT_OK(manager.AddLayerAlias("char_alias", "char"));
  TF_ASSERT_OK(manager.AddLayerAlias("int16_alias", "int16"));

  TF_ASSERT_OK(manager.AddComponent("bar"));
  TF_ASSERT_OK(manager.AddLayer("uint16", 11, &uint16_handle));
  TF_ASSERT_OK(manager.AddLayer("pairwise_char", 13, &pairwise_char_handle));
  TF_ASSERT_OK(manager.AddLayerAlias("uint16_alias", "uint16"));
  TF_ASSERT_OK(manager.AddLayerAlias("pairwise_char_alias", "pairwise_char"));

  // Try looking up unknown components.
  EXPECT_THAT(manager.LookupLayer("missing", "char", &dimension, &char_handle),
              test::IsErrorWithSubstr("Unknown component 'missing'"));
  EXPECT_THAT(manager.LookupLayer("baz", "float", &dimension, &char_handle),
              test::IsErrorWithSubstr("Unknown component 'baz'"));

  // Try looking up valid components but unknown layers.
  EXPECT_THAT(
      manager.LookupLayer("foo", "missing", &dimension, &char_handle),
      test::IsErrorWithSubstr("Unknown layer 'missing' in component 'foo'"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "missing", &dimension, &char_handle),
      test::IsErrorWithSubstr("Unknown layer 'missing' in component 'bar'"));

  // Try looking up valid components and the names of layers or aliases in the
  // other components.
  EXPECT_THAT(
      manager.LookupLayer("foo", "uint16", &dimension, &uint16_handle),
      test::IsErrorWithSubstr("Unknown layer 'uint16' in component 'foo'"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "uint16_alias", &dimension, &uint16_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'uint16_alias' in component 'foo'"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "char", &dimension, &char_handle),
      test::IsErrorWithSubstr("Unknown layer 'char' in component 'bar'"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "char_alias", &dimension, &char_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'char_alias' in component 'bar'"));

  // Look up layers with incorrect types.
  EXPECT_THAT(
      manager.LookupLayer("foo", "char", &dimension, &int16_handle),
      test::IsErrorWithSubstr(
          "Layer 'char' in component 'foo' does not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "char", &dimension, &uint16_handle),
      test::IsErrorWithSubstr(
          "Layer 'char' in component 'foo' does not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "char", &dimension, &pairwise_char_handle),
      test::IsErrorWithSubstr("Layer 'char' in component 'foo' does not match "
                              "its expected OperandType"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "int16", &dimension, &char_handle),
      test::IsErrorWithSubstr(
          "Layer 'int16' in component 'foo' does not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "int16", &dimension, &uint16_handle),
      test::IsErrorWithSubstr(
          "Layer 'int16' in component 'foo' does not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("foo", "int16", &dimension, &pairwise_char_handle),
      test::IsErrorWithSubstr(
          "Layer 'int16' in component 'foo' does not match its expected type"));
  EXPECT_THAT(manager.LookupLayer("bar", "uint16", &dimension, &char_handle),
              test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
                                      "not match its expected type"));
  EXPECT_THAT(manager.LookupLayer("bar", "uint16", &dimension, &int16_handle),
              test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
                                      "not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "uint16", &dimension, &pairwise_char_handle),
      test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
                              "not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "pairwise_char", &dimension, &char_handle),
      test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
                              "not match its expected OperandType"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "pairwise_char", &dimension, &int16_handle),
      test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
                              "not match its expected type"));
  EXPECT_THAT(
      manager.LookupLayer("bar", "pairwise_char", &dimension, &uint16_handle),
      test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
                              "not match its expected type"));

  // Look up layers properly, and check their dimensions.  Also verify that the
  // looked-up handles are identical to the original handles.
  LayerHandle<char> lookup_char_handle;
  LayerHandle<int16> lookup_int16_handle;
  LayerHandle<uint16> lookup_uint16_handle;
  PairwiseLayerHandle<char> lookup_pairwise_char_handle;

  TF_EXPECT_OK(
      manager.LookupLayer("foo", "char", &dimension, &lookup_char_handle));
  EXPECT_EQ(dimension, 5);
  ExpectBitwiseEqual(lookup_char_handle, char_handle);

  TF_EXPECT_OK(
      manager.LookupLayer("foo", "int16", &dimension, &lookup_int16_handle));
  EXPECT_EQ(dimension, 7);
  ExpectBitwiseEqual(lookup_int16_handle, int16_handle);

  TF_EXPECT_OK(
      manager.LookupLayer("bar", "uint16", &dimension, &lookup_uint16_handle));
  EXPECT_EQ(dimension, 11);
  ExpectBitwiseEqual(lookup_uint16_handle, uint16_handle);

  TF_EXPECT_OK(manager.LookupLayer("bar", "pairwise_char", &dimension,
                                   &lookup_pairwise_char_handle));
  EXPECT_EQ(dimension, 13);
  ExpectBitwiseEqual(lookup_pairwise_char_handle, pairwise_char_handle);
}

// Tests that NetworkStates cannot start components without a manager.
TEST(NetworkStatesTest, NoManager) {
  NetworkStates network_states;
  EXPECT_THAT(network_states.StartNextComponent(10),
              test::IsErrorWithSubstr("No manager"));
}

// Tests that NetworkStates cannot start components when the manager is empty.
TEST(NetworkStatesTest, EmptyManager) {
  NetworkStateManager empty_manager;
  NetworkStates network_states;
  network_states.Reset(&empty_manager);
  EXPECT_THAT(network_states.StartNextComponent(10),
              test::IsErrorWithSubstr("No next component"));
}

// Tests that NetworkStates can start the same number of components as were
// configured in its manager.
TEST(NetworkStatesTest, StartNextComponent) {
  NetworkStateManager manager;
  TF_EXPECT_OK(manager.AddComponent("foo"));
  TF_EXPECT_OK(manager.AddComponent("bar"));
  TF_EXPECT_OK(manager.AddComponent("baz"));

  NetworkStates network_states;
  network_states.Reset(&manager);
  TF_EXPECT_OK(network_states.StartNextComponent(10));
  TF_EXPECT_OK(network_states.StartNextComponent(11));
  TF_EXPECT_OK(network_states.StartNextComponent(12));
  EXPECT_THAT(network_states.StartNextComponent(13),
              test::IsErrorWithSubstr("No next component"));
}

// Tests that NetworkStates contains layers and locals whose dimensions match
// the configuration of its manager.
TEST(NetworkStatesTest, Dimensions) {
  NetworkStateManager manager;

  // The "foo" component has two layers and a local vector.
  LayerHandle<float> foo_hidden_handle;
  LocalVectorHandle<int16> foo_local_handle;
  PairwiseLayerHandle<float> foo_logits_handle;
  TF_ASSERT_OK(manager.AddComponent("foo"));
  TF_ASSERT_OK(manager.AddLayer("hidden", 10, &foo_hidden_handle));
  TF_ASSERT_OK(manager.AddLocal(20, &foo_local_handle));
  TF_ASSERT_OK(manager.AddLayer("logits", 30, &foo_logits_handle));

  // The "bar" component has one layer and a local matrix.
  LayerHandle<float> bar_logits_handle;
  LocalMatrixHandle<bool> bar_local_handle;
  TF_ASSERT_OK(manager.AddComponent("bar"));
  TF_ASSERT_OK(manager.AddLayer("logits", 40, &bar_logits_handle));
  TF_ASSERT_OK(manager.AddLocal(50, &bar_local_handle));

  // Initialize a NetworkStates and check its dimensions.  Note that matrices
  // start with 0 rows since there are 0 steps.
  NetworkStates network_states;
  network_states.Reset(&manager);
  TF_EXPECT_OK(network_states.StartNextComponent(13));
  ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 0, 10);
  EXPECT_EQ(network_states.GetLocal(foo_local_handle).size(), 20);
  ExpectDimensions(network_states.GetLayer(foo_logits_handle), 0, 0);

  // Add some steps, and check that rows have been added to matrices, while
  // vectors are unaffected.
  network_states.AddSteps(19);
  ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 19, 10);
  EXPECT_EQ(network_states.GetLocal(foo_local_handle).size(), 20);
  ExpectDimensions(network_states.GetLayer(foo_logits_handle), 19, 19 * 30);

  // Again for the next component.
  TF_EXPECT_OK(network_states.StartNextComponent(9));
  ExpectDimensions(network_states.GetLayer(bar_logits_handle), 0, 40);
  ExpectDimensions(network_states.GetLocal(bar_local_handle), 0, 50);

  // Add some steps, and check that rows have been added to matrices.
  network_states.AddSteps(25);
  ExpectDimensions(network_states.GetLayer(bar_logits_handle), 25, 40);
  ExpectDimensions(network_states.GetLocal(bar_local_handle), 25, 50);

  EXPECT_THAT(network_states.StartNextComponent(10),
              test::IsErrorWithSubstr("No next component"));

  // Check the layers of the first component.  They should still have the same
  // dimensions in spite of adding steps to the second component.
  ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 19, 10);
  ExpectDimensions(network_states.GetLayer(foo_logits_handle), 19, 19 * 30);
}

// Tests that NetworkStates can be reused by resetting them repeatedly,
// possibly switching between different managers.
TEST(NetworkStatesTest, ResetWithDifferentManagers) {
  std::vector<NetworkStateManager> managers(10);
  std::vector<LayerHandle<int>> layer_handles(10);
  std::vector<PairwiseLayerHandle<int>> pairwise_layer_handles(10);
  std::vector<LocalVectorHandle<int>> vector_handles(10);
  std::vector<LocalMatrixHandle<double>> matrix_handles(10);
  for (int dim = 0; dim < 10; ++dim) {
    TF_ASSERT_OK(managers[dim].AddComponent("foo"));
    TF_ASSERT_OK(managers[dim].AddLayer(
        tensorflow::strings::StrCat("layer", dim), dim, &layer_handles[dim]));
    TF_ASSERT_OK(managers[dim].AddLayer(
        tensorflow::strings::StrCat("pairwise", dim), dim,
        &pairwise_layer_handles[dim]));
    TF_ASSERT_OK(managers[dim].AddLocal(dim, &vector_handles[dim]));
    TF_ASSERT_OK(managers[dim].AddLocal(dim, &matrix_handles[dim]));
  }

  NetworkStates network_states;
  for (int trial = 0; trial < 10; ++trial) {
    for (int dim = 0; dim < 10; ++dim) {
      network_states.Reset(&managers[dim]);
      TF_ASSERT_OK(network_states.StartNextComponent(10));

      // Fill the vector local.
      Fill(network_states.GetLocal(vector_handles[dim]), dim,
           100 * trial + dim);

      // Check the vector local.
      ExpectFilled(network_states.GetLocal(vector_handles[dim]), dim,
                   100 * trial + dim);

      // Repeatedly add a step and fill it with values.
      for (int step = 0; step < 100; ++step) {
        network_states.AddStep();
        Fill(network_states.GetLayer(layer_handles[dim]).row(step), dim,
             1000 * trial + 100 * dim + step);
        Fill(network_states.GetLocal(matrix_handles[dim]).row(step), dim,
             9876.0 * trial + 100 * dim + step);
      }

      // Check that data from earlier steps is preserved across reallocations.
      for (int step = 0; step < 100; ++step) {
        ExpectFilled(network_states.GetLayer(layer_handles[dim]).row(step),
                     dim, 1000 * trial + 100 * dim + step);
        ExpectFilled(network_states.GetLocal(matrix_handles[dim]).row(step),
                     dim, 9876.0 * trial + 100 * dim + step);
      }

      ExpectDimensions(network_states.GetLayer(pairwise_layer_handles[dim]),
                       100, 100 * dim);
    }
  }
}

// Tests that one NetworkStateManager can be shared simultaneously between
// multiple NetworkStates instances.
TEST(NetworkStatesTest, SharedManager) {
  const size_t kDim = 17;
  NetworkStateManager manager;
  LayerHandle<int> layer_handle;
  PairwiseLayerHandle<int> pairwise_layer_handle;
  LocalVectorHandle<int> vector_handle;
  LocalMatrixHandle<double> matrix_handle;
  TF_ASSERT_OK(manager.AddComponent("foo"));
  TF_ASSERT_OK(manager.AddLayer("layer", kDim, &layer_handle));
  TF_ASSERT_OK(manager.AddLayer("pairwise", kDim, &pairwise_layer_handle));
  TF_ASSERT_OK(manager.AddLocal(kDim, &vector_handle));
  TF_ASSERT_OK(manager.AddLocal(kDim, &matrix_handle));

  std::vector<NetworkStates> network_states_vec(10);
  for (NetworkStates &network_states : network_states_vec) {
    network_states.Reset(&manager);
    TF_ASSERT_OK(network_states.StartNextComponent(10));
  }

  // Fill all vectors.
  for (int trial = 0; trial < network_states_vec.size(); ++trial) {
    const NetworkStates &network_states = network_states_vec[trial];
    Fill(network_states.GetLocal(vector_handle), kDim, 3 * trial);
  }

  // Check all vectors.
  for (int trial = 0; trial < network_states_vec.size(); ++trial) {
    const NetworkStates &network_states = network_states_vec[trial];
    ExpectFilled(network_states.GetLocal(vector_handle), kDim, 3 * trial);
  }

  // Fill all matrices.  Interleave operations on the network states on each
  // step, so all network states are "active" at the same time.
  for (int step = 0; step < 100; ++step) {
    for (int trial = 0; trial < 10; ++trial) {
      NetworkStates &network_states = network_states_vec[trial];
      network_states.AddStep();
      Fill(network_states.GetLayer(layer_handle).row(step), kDim,
           999 * trial + step);
      Fill(network_states.GetLocal(matrix_handle).row(step), kDim,
           1234.0 * trial + step);
      ExpectDimensions(network_states.GetLayer(pairwise_layer_handle),
                       step + 1, kDim * (step + 1));
    }
  }

  // Check all matrices.
  for (int step = 0; step < 100; ++step) {
    for (int trial = 0; trial < 10; ++trial) {
      const NetworkStates &network_states = network_states_vec[trial];
      ExpectFilled(network_states.GetLayer(layer_handle).row(step), kDim,
                   999 * trial + step);
      ExpectFilled(network_states.GetLocal(matrix_handle).row(step), kDim,
                   1234.0 * trial + step);
      ExpectDimensions(network_states.GetLayer(pairwise_layer_handle), 100,
                       kDim * 100);
    }
  }
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/runtime/network_unit.h"

#include <vector>

#include "tensorflow/core/lib/strings/str_util.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {

string NetworkUnit::GetClassName(const ComponentSpec &component_spec) {
  // The Python registration API is based on (relative) module paths, such as
  // "some.module.FooNetwork".  Therefore, we discard the module path prefix
  // and use only the final segment, which is the subclass name.
  const std::vector<string> segments = tensorflow::str_util::Split(
      component_spec.network_unit().registered_name(), ".");
  CHECK_GT(segments.size(), 0)
      << "No network unit name for component spec: "
      << component_spec.ShortDebugString();
  return segments.back();
}

}  // namespace runtime
}  // namespace dragnn

REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
                                  dragnn::runtime::NetworkUnit);

}  // namespace syntaxnet
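
As an illustration of GetClassName() (the spec value below is hypothetical, not from this commit), a registered_name carrying a Python module path is reduced to its final segment:

// Hypothetical spec value; mutable_network_unit()/set_registered_name are the
// standard generated proto accessors assumed to exist for this field.
syntaxnet::dragnn::ComponentSpec component_spec;
component_spec.mutable_network_unit()->set_registered_name(
    "some.module.FooNetwork");

// Discards "some.module." and keeps only the subclass name.
const string class_name =
    syntaxnet::dragnn::runtime::NetworkUnit::GetClassName(component_spec);
// class_name == "FooNetwork"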
research/syntaxnet/dragnn/runtime/network_unit.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_H_

#include <stddef.h>

#include <string>

#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"

namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Interface for network units for sequential inference.
class NetworkUnit : public RegisterableClass<NetworkUnit> {
 public:
  NetworkUnit(const NetworkUnit &that) = delete;
  NetworkUnit &operator=(const NetworkUnit &that) = delete;
  virtual ~NetworkUnit() = default;

  // Returns the network unit class name specified in the |component_spec|.
  static string GetClassName(const ComponentSpec &component_spec);

  // Initializes this to the configuration in the |component_spec|.  Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component.  Requests SessionState extensions
  // from the |extension_manager|.  On error, returns non-OK.
  virtual tensorflow::Status Initialize(
      const ComponentSpec &component_spec, VariableStore *variable_store,
      NetworkStateManager *network_state_manager,
      ExtensionManager *extension_manager) = 0;

  // Returns the name of the layer that contains classification logits, or an
  // empty string if this does not produce logits.  Requires that Initialize()
  // was called.
  virtual string GetLogitsName() const = 0;

  // Evaluates this network unit on the |session_state| and |compute_session|.
  // Requires that:
  // * The network states in the |session_state| are positioned at the current
  //   component, which must have at least |step_index|+1 steps.
  // * The same component in the |compute_session| must have traversed
  //   |step_index| transitions.
  // * Initialize() was called.
  // On error, returns non-OK.
  virtual tensorflow::Status Evaluate(
      size_t step_index, SessionState *session_state,
      ComputeSession *compute_session) const = 0;

 protected:
  NetworkUnit() = default;

 private:
  // Helps prevent use of the Create() method; use CreateOrError() instead.
  using RegisterableClass<NetworkUnit>::Create;
};

}  // namespace runtime
}  // namespace dragnn

DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
                                 dragnn::runtime::NetworkUnit);

}  // namespace syntaxnet

// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(subclass) \
  REGISTER_SYNTAXNET_CLASS_COMPONENT( \
      ::syntaxnet::dragnn::runtime::NetworkUnit, #subclass, subclass)

#endif  // DRAGNN_RUNTIME_NETWORK_UNIT_H_
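
For orientation, a do-nothing subclass sketched against the interface above. The class is invented for illustration and is not part of this commit; a real unit (e.g. the FeedForwardNetwork named in the master-spec) would add layers in Initialize() and fill them in Evaluate().

namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Hypothetical subclass: accepts any spec and computes nothing.
class NoOpNetwork : public NetworkUnit {
 public:
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();  // no variables, layers, or extensions
  }

  // No logits layer is produced.
  string GetLogitsName() const override { return ""; }

  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override {
    return tensorflow::Status::OK();  // nothing to compute per step
  }
};

// Makes the class retrievable by name via the registry, which is how
// GetClassName() output is resolved to an implementation.
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(NoOpNetwork);

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet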
research/syntaxnet/dragnn/runtime/network_unit_base.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <string.h>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Returns the sum of the dimensions of all channels in the |manager|.  The
// EmbeddingManager template type should be either FixedEmbeddingManager or
// LinkedEmbeddingManager; note that both share the same API.
template <class EmbeddingManager>
size_t SumEmbeddingDimensions(const EmbeddingManager &manager) {
  size_t sum = 0;
  for (size_t i = 0; i < manager.num_channels(); ++i) {
    sum += manager.embedding_dim(i);
  }
  return sum;
}

// Copies each channel of the |embeddings| into the region starting at |data|.
// Returns a pointer to one past the last element of the copied region.  The
// Embeddings type should be FixedEmbeddings or LinkedEmbeddings; note that both
// have the same API.
//
// TODO(googleuser): Try a vectorized copy instead of memcpy().  Unclear whether
// we can do better, though.  For one, the memcpy() implementation may already
// be vectorized.  Also, while the input embeddings are aligned, the output is
// not; e.g., consider concatenating inputs with dims 7 and 9.  This could be
// addressed by requiring that embedding dims are aligned, or by handling the
// unaligned prefix separately.
//
// TODO(googleuser): Consider alternatives for handling fixed feature channels
// with size>1.  The least surprising approach is to concatenate the size>1
// embeddings inside FixedEmbeddings, so the channel IDs still correspond to
// positions in the ComponentSpec.fixed_feature list.  However, that means the
// same embedding gets copied twice, once there and once here.  Conversely, we
// could split the size>1 embeddings into separate channels, eliding a copy
// while obfuscating the channel IDs.  IMO, separate channels seem better
// because very few bits of DRAGNN actually access individual channels, and I
// wrote many of those bits.
template <class Embeddings>
float *CopyEmbeddings(const Embeddings &embeddings, float *data) {
  for (size_t i = 0; i < embeddings.num_embeddings(); ++i) {
    const Vector<float> vector = embeddings.embedding(i);
    memcpy(data, vector.data(), vector.size() * sizeof(float));
    data += vector.size();
  }
  return data;
}

}  // namespace
tensorflow::Status NetworkUnitBase::InitializeBase(
    bool use_concatenated_input, const ComponentSpec &component_spec,
    VariableStore *variable_store, NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  use_concatenated_input_ = use_concatenated_input;
  num_actions_ = component_spec.num_actions();
  TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  concatenated_input_dim_ = SumEmbeddingDimensions(fixed_embedding_manager_) +
                            SumEmbeddingDimensions(linked_embedding_manager_);

  if (use_concatenated_input_) {
    // If there is <= 1 input embedding, then the concatenation is trivial and
    // we don't need a local vector; see ConcatenateInput().
    const size_t num_embeddings = fixed_embedding_manager_.num_embeddings() +
                                  linked_embedding_manager_.num_embeddings();
    if (num_embeddings > 1) {
      TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
          concatenated_input_dim_, &concatenated_input_handle_));
    }

    // Check that all fixed features are embedded.
    for (size_t i = 0; i < fixed_embedding_manager_.num_channels(); ++i) {
      if (!fixed_embedding_manager_.is_embedded(i)) {
        return tensorflow::errors::InvalidArgument(
            "Non-embedded fixed features cannot be concatenated");
      }
    }
  }

  extension_manager->GetShared(&fixed_embeddings_handle_);
  extension_manager->GetShared(&linked_embeddings_handle_);
  return tensorflow::Status::OK();
}
tensorflow::Status NetworkUnitBase::EvaluateBase(
    SessionState *session_state, ComputeSession *compute_session,
    Vector<float> *concatenated_input) const {
  FixedEmbeddings &fixed_embeddings =
      session_state->extensions.Get(fixed_embeddings_handle_);
  LinkedEmbeddings &linked_embeddings =
      session_state->extensions.Get(linked_embeddings_handle_);
  TF_RETURN_IF_ERROR(fixed_embeddings.Reset(&fixed_embedding_manager_,
                                            session_state->network_states,
                                            compute_session));
  TF_RETURN_IF_ERROR(linked_embeddings.Reset(&linked_embedding_manager_,
                                             session_state->network_states,
                                             compute_session));
  if (use_concatenated_input_ && concatenated_input != nullptr) {
    *concatenated_input = ConcatenateInput(session_state);
  }
  return tensorflow::Status::OK();
}
Vector<float> NetworkUnitBase::ConcatenateInput(
    SessionState *session_state) const {
  DCHECK(use_concatenated_input_);
  const FixedEmbeddings &fixed_embeddings =
      session_state->extensions.Get(fixed_embeddings_handle_);
  const LinkedEmbeddings &linked_embeddings =
      session_state->extensions.Get(linked_embeddings_handle_);
  const size_t num_embeddings =
      fixed_embeddings.num_embeddings() + linked_embeddings.num_embeddings();

  // Special cases where no actual concatenation is required.
  if (num_embeddings == 0) return {};
  if (num_embeddings == 1) {
    return fixed_embeddings.num_embeddings() > 0
               ? fixed_embeddings.embedding(0)
               : linked_embeddings.embedding(0);
  }

  // General case; concatenate into a local vector.  The ordering of embeddings
  // must be exactly the same as in the Python codebase, which is:
  // 1. Fixed embeddings before linked embeddings (see get_input_tensor() in
  //    network_units.py).
  // 2. In each type, ordered as listed in ComponentSpec.fixed/linked_feature
  //    (see DynamicComponentBuilder._feedforward_unit() in component.py).
  //
  // Since FixedEmbeddings and LinkedEmbeddings already follow the order defined
  // in the ComponentSpec, it suffices to append each fixed embedding, then each
  // linked embedding.
  const MutableVector<float> concatenation =
      session_state->network_states.GetLocal(concatenated_input_handle_);
  float *data = concatenation.data();
  data = CopyEmbeddings(fixed_embeddings, data);
  data = CopyEmbeddings(linked_embeddings, data);
  DCHECK_EQ(data, concatenation.end());
  return Vector<float>(concatenation);
}
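// Worked example (illustrative; not from the original file): with fixed
// channels of dims {13, 19} and linked channels of dims {77, 123}, as in the
// tests below, the concatenation has dim 13+19+77+123 = 232, laid out as
//   [0, 13)    fixed channel 0
//   [13, 32)   fixed channel 1
//   [32, 109)  linked channel 0
//   [109, 232) linked channel 1
// per the fixed-before-linked, spec-order convention documented above.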
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit_base.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#include <stddef.h>
#include <utility>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A base class for network units that provides common functionality, analogous
// to NetworkUnitInterface.__init__() in network_units.py.  Specifically, this
// class manages and builds input embeddings and, as a convenience, optionally
// concatenates the input embeddings into a single vector.
//
// Since recurrent layers are both outputs and inputs, they complicate network
// unit initialization.  In particular, the linked embeddings cannot be set up
// until the characteristics of all recurrently-accessible layers are known.
// On the other hand, some layers cannot be initialized until all inputs,
// including the linked embeddings, are set up.  For example, the
// IdentityNetwork outputs a layer whose dimension is the sum of all input
// dimensions.
//
// To accommodate recurrent layers, network unit initialization is organized
// into three phases:
//   1. (Subclass) Initialize all recurrently-accessible layers.
//   2. (This class) Initialize embedding managers and other common state.
//   3. (Subclass) Initialize any non-recurrent layers.
//
// Concretely, the subclass's Initialize() should first add recurrent layers,
// then call InitializeBase(), and finally finish initializing.  Evaluation is
// simpler: the subclass's Evaluate() may call EvaluateBase() at any time.
//
// Note: Network unit initialization is similarly interleaved between base and
// subclasses in the Python codebase; see NetworkUnitInterface.get_layer_size()
// and the "init_layers" argument to NetworkUnitInterface.__init__().
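//
// As a sketch of the three phases (illustrative layer names, dims, and handle
// members; not prescribed by this class), a subclass Initialize() typically
// looks like:
//
//   tensorflow::Status Initialize(const ComponentSpec &spec,
//                                 VariableStore *vars,
//                                 NetworkStateManager *states,
//                                 ExtensionManager *extensions) override {
//     // Phase 1: add recurrently-accessible layers.
//     TF_RETURN_IF_ERROR(states->AddLayer("hidden", kHiddenDim, &hidden_));
//     // Phase 2: initialize embedding managers and common state.
//     TF_RETURN_IF_ERROR(InitializeBase(/*use_concatenated_input=*/true, spec,
//                                       vars, states, extensions));
//     // Phase 3: add layers whose dims depend on the inputs.
//     return states->AddLayer("logits", num_actions(), &logits_);
//   }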
class NetworkUnitBase : public NetworkUnit {
 public:
  // Initializes common state as configured in the |component_spec|.  Retrieves
  // pre-trained embedding matrices from the |variable_store|.  Looks up linked
  // embeddings in the |network_state_manager|, which must contain all
  // recurrent layers.  Requests any required extensions from the
  // |extension_manager|.  If |use_concatenated_input| is true, prepares to
  // concatenate input embeddings in EvaluateBase().  On error, returns non-OK.
  tensorflow::Status InitializeBase(bool use_concatenated_input,
                                    const ComponentSpec &component_spec,
                                    VariableStore *variable_store,
                                    NetworkStateManager *network_state_manager,
                                    ExtensionManager *extension_manager);

  // Resets the fixed and linked embeddings in the |session_state| using its
  // network states and the |compute_session|.  Requires that InitializeBase()
  // was called.  If this was prepared for concatenation (see InitializeBase())
  // and if |concatenated_input| is non-null, points it at the concatenation of
  // the fixed and linked embeddings.  Otherwise, no concatenation occurs.  On
  // error, returns non-OK.
  tensorflow::Status EvaluateBase(SessionState *session_state,
                                  ComputeSession *compute_session,
                                  Vector<float> *concatenated_input) const;

  // Accessors.  All require that InitializeBase() was called.
  const FixedEmbeddingManager &fixed_embedding_manager() const;
  const LinkedEmbeddingManager &linked_embedding_manager() const;
  size_t num_actions() const { return num_actions_; }
  size_t concatenated_input_dim() const { return concatenated_input_dim_; }
 private:
  // Returns the concatenation of the fixed and linked embeddings in the
  // |session_state|.  Requires that |use_concatenated_input_| is true.
  Vector<float> ConcatenateInput(SessionState *session_state) const;
  // Managers for fixed and linked embeddings in this component.
  FixedEmbeddingManager fixed_embedding_manager_;
  LinkedEmbeddingManager linked_embedding_manager_;

  // Fixed and linked embeddings.
  SharedExtensionHandle<FixedEmbeddings> fixed_embeddings_handle_;
  SharedExtensionHandle<LinkedEmbeddings> linked_embeddings_handle_;

  // Number of actions supported by the transition system.
  size_t num_actions_ = 0;

  // Sum of dimensions of all fixed and linked embeddings.
  size_t concatenated_input_dim_ = 0;

  // Whether to concatenate the input embeddings.
  bool use_concatenated_input_ = false;

  // Handle of the vector that holds the concatenated input, or invalid if no
  // concatenation is required.
  LocalVectorHandle<float> concatenated_input_handle_;
};

// Implementation details below.

inline const FixedEmbeddingManager &NetworkUnitBase::fixed_embedding_manager()
    const {
  return fixed_embedding_manager_;
}

inline const LinkedEmbeddingManager &
NetworkUnitBase::linked_embedding_manager() const {
  return linked_embedding_manager_;
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
research/syntaxnet/dragnn/runtime/network_unit_base_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;

// Dimensions of the layers in the network.
static constexpr size_t kPreviousDim = 77;
static constexpr size_t kRecurrentDim = 123;

// Contents of the layers in the network.
static constexpr float kPreviousValue = -2.75;
static constexpr float kRecurrentValue = 6.25;

// Number of steps taken in each component.
static constexpr size_t kNumSteps = 10;

// A trivial network unit that exposes the concatenated inputs.  Note that
// NetworkUnitBase does not override the interface methods, so we need a
// concrete subclass for testing.
class FooNetwork : public NetworkUnitBase {
 public:
  void RequestConcatenation() { request_concatenation_ = true; }
  void ProvideConcatenatedInput() { provide_concatenated_input_ = true; }
  Vector<float> concatenated_input() const { return concatenated_input_; }

  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    TF_RETURN_IF_ERROR(network_state_manager->AddLayer(
        "recurrent_layer", kRecurrentDim, &recurrent_handle_));
    return InitializeBase(request_concatenation_, component_spec,
                          variable_store, network_state_manager,
                          extension_manager);
  }
  string GetLogitsName() const override { return ""; }
  tensorflow::Status Evaluate(size_t unused_step_index,
                              SessionState *session_state,
                              ComputeSession *compute_session) const override {
    return EvaluateBase(session_state, compute_session,
                        provide_concatenated_input_ ? &concatenated_input_
                                                    : nullptr);
  }

 private:
  bool request_concatenation_ = false;
  bool provide_concatenated_input_ = false;
  LayerHandle<float> recurrent_handle_;
  mutable Vector<float> concatenated_input_;  // Evaluate() sets this
};
class NetworkUnitBaseTest : public NetworkTestBase {
 protected:
  // Initializes the |network_unit_| based on the |component_spec_text| and
  // evaluates it.  On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);

    AddComponent("previous_component");
    AddLayer("previous_layer", kPreviousDim);
    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(network_unit_.Initialize(component_spec,
                                                &variable_store_,
                                                &network_state_manager_,
                                                &extension_manager_));

    // Create and populate the network states.
    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    StartComponent(kNumSteps);
    FillLayer("previous_component", "previous_layer", kPreviousValue);
    FillLayer(kTestComponentName, "recurrent_layer", kRecurrentValue);
    session_state_.extensions.Reset(&extension_manager_);

    // Neither FooNetwork nor NetworkUnitBase looks at the step index, so use
    // an arbitrary value.
    return network_unit_.Evaluate(0, &session_state_, &compute_session_);
  }

  FooNetwork network_unit_;
  std::vector<std::vector<float>> concatenated_inputs_;
};
// Tests that NetworkUnitBase produces an empty vector when concatenating and
// there are no input embeddings.
TEST_F(NetworkUnitBaseTest, ConcatenateNoInputs) {
  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(""));

  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.num_actions(), 0);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), 0);
  EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single fixed channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneFixedChannel) {
  const float kEmbedding = 1.5;
  const float kFeature = 0.5;
  const size_t kDim = 13;
  const string kSpec = R"(num_actions: 42
                          fixed_feature {
                            vocabulary_size: 11
                            embedding_dim: 13
                            size: 1
                          })";
  AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
  const float kValue = kEmbedding * kFeature;

  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(kSpec));

  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.num_actions(), 42);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
  ExpectVector(network_unit_.concatenated_input(),
               network_unit_.concatenated_input_dim(), kValue);
}
// Tests that NetworkUnitBase does not concatenate if concatenation is requested
// and the concatenated input vector is not provided.
TEST_F(NetworkUnitBaseTest, ConcatenatedInputVectorNotProvided) {
  const float kEmbedding = 1.5;
  const float kFeature = 0.5;
  const size_t kDim = 13;
  const string kSpec = R"(num_actions: 37
                          fixed_feature {
                            vocabulary_size: 11
                            embedding_dim: 13
                            size: 1
                          })";
  AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));

  network_unit_.RequestConcatenation();
  TF_ASSERT_OK(Run(kSpec));

  // The embedding managers and other configuration are set up properly.
  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.num_actions(), 37);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);

  // But the concatenation was not performed.
  EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// As above, but with the converse condition: does not request concatenation,
// but does provide the concatenated input vector.
TEST_F(NetworkUnitBaseTest, ConcatenationNotRequested) {
  const float kEmbedding = 1.5;
  const float kFeature = 0.5;
  const size_t kDim = 13;
  const string kSpec = R"(num_actions: 31
                          fixed_feature {
                            vocabulary_size: 11
                            embedding_dim: 13
                            size: 1
                          })";
  AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));

  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(kSpec));

  // The embedding managers and other configuration are set up properly.
  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.num_actions(), 31);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);

  // But the concatenation was not performed.
  EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single linked channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneLinkedChannel) {
  const string kSpec = R"(num_actions: 37
                          linked_feature {
                            embedding_dim: -1
                            source_component: 'previous_component'
                            source_layer: 'previous_layer'
                            size: 1
                          })";
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(kSpec));

  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.num_actions(), 37);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), kPreviousDim);
  ExpectVector(network_unit_.concatenated_input(),
               network_unit_.concatenated_input_dim(), kPreviousValue);
}
// Tests that NetworkUnitBase concatenates a fixed and linked channel in that
// order.
TEST_F(NetworkUnitBaseTest, ConcatenateOneChannelOfEachType) {
  const float kEmbedding = 1.25;
  const float kFeature = 0.75;
  const size_t kFixedDim = 13;
  const string kSpec = R"(num_actions: 77
                          fixed_feature {
                            vocabulary_size: 11
                            embedding_dim: 13
                            size: 1
                          }
                          linked_feature {
                            embedding_dim: -1
                            source_component: 'previous_component'
                            source_layer: 'previous_layer'
                            size: 1
                          })";
  AddFixedEmbeddingMatrix(0, 11, kFixedDim, kEmbedding);
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
  const float kFixedValue = kEmbedding * kFeature;
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(kSpec));

  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
  EXPECT_EQ(network_unit_.num_actions(), 77);
  EXPECT_EQ(network_unit_.concatenated_input_dim(), kFixedDim + kPreviousDim);

  // Check that each sub-segment is equal to one of the input embeddings.
  const Vector<float> input = network_unit_.concatenated_input();
  EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
  size_t index = 0;
  size_t end = kFixedDim;
  for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue);
  end += kPreviousDim;
  for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
}
// Tests that NetworkUnitBase produces a properly-ordered concatenation of
// multiple fixed and linked channels, including a recurrent channel.
TEST_F(NetworkUnitBaseTest, ConcatenateMultipleChannelsOfEachType) {
  const float kEmbedding0 = 1.25;
  const float kEmbedding1 = -0.125;
  const float kFeature0 = 0.75;
  const float kFeature1 = -2.5;
  const size_t kFixedDim0 = 13;
  const size_t kFixedDim1 = 19;
  const string kSpec = R"(num_actions: 99
                          fixed_feature {
                            vocabulary_size: 11
                            embedding_dim: 13
                            size: 1
                          }
                          fixed_feature {
                            vocabulary_size: 17
                            embedding_dim: 19
                            size: 1
                          }
                          linked_feature {
                            embedding_dim: -1
                            source_component: 'previous_component'
                            source_layer: 'previous_layer'
                            size: 1
                          }
                          linked_feature {
                            embedding_dim: -1
                            source_component: 'test_component'
                            source_layer: 'recurrent_layer'
                            size: 1
                          })";
  AddFixedEmbeddingMatrix(0, 11, kFixedDim0, kEmbedding0);
  AddFixedEmbeddingMatrix(1, 17, kFixedDim1, kEmbedding1);
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature0}})))
      .WillOnce(Invoke(ExtractFeatures(1, {{1, kFeature1}})));
  const float kFixedValue0 = kEmbedding0 * kFeature0;
  const float kFixedValue1 = kEmbedding1 * kFeature1;
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})))
      .WillOnce(Invoke(ExtractLinks(1, {"step_idx: 6"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  TF_ASSERT_OK(Run(kSpec));

  EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 2);
  EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 2);
  EXPECT_EQ(network_unit_.num_actions(), 99);
  EXPECT_EQ(network_unit_.concatenated_input_dim(),
            kFixedDim0 + kFixedDim1 + kPreviousDim + kRecurrentDim);

  // Check that each sub-segment is equal to one of the input embeddings.  For
  // compatibility with the Python codebase, fixed channels must appear before
  // linked channels, and among each type order follows the ComponentSpec.
  const Vector<float> input = network_unit_.concatenated_input();
  EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
  size_t index = 0;
  size_t end = kFixedDim0;
  for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue0);
  end += kFixedDim1;
  for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue1);
  end += kPreviousDim;
  for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
  end += kRecurrentDim;
  for (; index < end; ++index) EXPECT_EQ(input[index], kRecurrentValue);
}
// Tests that NetworkUnitBase refuses to concatenate if there are non-embedded
// fixed features.
TEST_F(NetworkUnitBaseTest, CannotConcatenateNonEmbeddedFixedFeatures) {
  const string kBadSpec = R"(fixed_feature {
                               embedding_dim: -1
                               size: 1
                             })";

  network_unit_.RequestConcatenation();
  network_unit_.ProvideConcatenatedInput();
  EXPECT_THAT(Run(kBadSpec),
              test::IsErrorWithSubstr(
                  "Non-embedded fixed features cannot be concatenated"));
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/network_unit_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Expects that the two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}

// A trivial implementation for tests.
class FooNetwork : public NetworkUnit {
 public:
  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  string GetLogitsName() const override { return "foo_logits"; }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override {
    return tensorflow::Status::OK();
  }
};

DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(FooNetwork);

// Tests that a human-friendly error is produced for empty network units.
TEST(NetworkUnitTest, GetClassNameDegenerateName) {
  ComponentSpec component_spec;
  EXPECT_DEATH(NetworkUnit::GetClassName(component_spec),
               "No network unit name for component spec");
}

// Tests that NetworkUnit::GetClassName() resolves names properly.
TEST(NetworkUnitTest, GetClassName) {
  for (const string &registered_name :
       {"FooNetwork", "module.FooNetwork",
        "some.long.path.to.module.FooNetwork"}) {
    ComponentSpec component_spec;
    component_spec.mutable_network_unit()->set_registered_name(registered_name);
    EXPECT_EQ(NetworkUnit::GetClassName(component_spec), "FooNetwork");
  }
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/operands.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

OperandHandle OperandManager::Add(const OperandSpec &spec) {
  const size_t index = specs_.size();
  specs_.push_back(spec);

  switch (spec.type) {
    case OperandType::kSingular:
      handle_index_to_typed_index_.push_back(singular_spans_.size());
      singular_spans_.emplace_back(singular_size_, spec.size);
      singular_size_ += PadToAlignment(spec.size);
      break;

    case OperandType::kStepwise:
      handle_index_to_typed_index_.push_back(stepwise_spans_.size());
      stepwise_spans_.emplace_back(stepwise_stride_, spec.size);
      stepwise_stride_ += PadToAlignment(spec.size);
      break;

    case OperandType::kPairwise:
      handle_index_to_typed_index_.push_back(pairwise_sizes_.size());
      pairwise_sizes_.push_back(spec.size);
      break;
  }

  return OperandHandle(index);
}
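// Worked example of the bookkeeping above (illustrative, and it assumes a
// 32-byte alignment quantum in PadToAlignment(), which this file does not
// specify): Add({kSingular, 7}) followed by Add({kSingular, 40}) records
// singular spans (0, 7) and (32, 40), leaving singular_size_ = 32 + 64 = 96.
// Stepwise Add()s accumulate offsets into stepwise_stride_ the same way,
// per step rather than in total.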
void Operands::Reset(const OperandManager *manager,
                     size_t pre_allocate_num_steps) {
  manager_ = manager;
  handle_index_to_typed_index_ = manager_->handle_index_to_typed_index_;
  stepwise_spans_ = manager_->stepwise_spans_;
  stepwise_stride_ = manager_->stepwise_stride_;
  pairwise_sizes_ = manager_->pairwise_sizes_;

  // Allocate and parcel out singular operands.
  singular_operands_.clear();
  singular_operands_.reserve(manager_->singular_spans_.size());
  singular_array_.Reserve(manager_->singular_size_);
  char *data = singular_array_.view().data();
  for (const auto &span : manager_->singular_spans_) {
    singular_operands_.push_back(
        MutableAlignedView(data + span.first, span.second));
  }

  // Pre-allocate and parcel out stepwise operands.
  stepwise_operands_.clear();
  stepwise_operands_.reserve(stepwise_spans_.size());
  stepwise_array_.Reserve(stepwise_stride_ * pre_allocate_num_steps);
  data = stepwise_array_.view().data();
  for (const auto &span : stepwise_spans_) {
    stepwise_operands_.push_back(MutableAlignedArea(
        data + span.first, 0, span.second, stepwise_stride_));
  }

  // Create empty pairwise operands.
  pairwise_operands_.clear();
  pairwise_operands_.resize(pairwise_sizes_.size());
}
void Operands::AddSteps(size_t num_steps) {
  AddStepwiseSteps(num_steps);
  AddPairwiseSteps(num_steps);
}

void Operands::AddStepwiseSteps(size_t num_steps) {
  if (stepwise_operands_.empty()) return;

  // Make room for the new steps.
  const size_t new_num_views = stepwise_operands_[0].num_views_ + num_steps;
  const bool actually_reallocated =
      stepwise_array_.Resize(new_num_views * stepwise_stride_);

  // Update the base pointers for stepwise operands, if changed.
  if (actually_reallocated) {
    char *data = stepwise_array_.view().data();
    for (size_t i = 0; i < stepwise_operands_.size(); ++i) {
      stepwise_operands_[i].data_ = data + stepwise_spans_[i].first;
    }
  }

  // Update the number of views in each stepwise operand.
  for (MutableAlignedArea &operand : stepwise_operands_) {
    operand.num_views_ = new_num_views;
  }
}
void Operands::AddPairwiseSteps(size_t num_steps) {
  if (pairwise_operands_.empty()) return;
  const size_t new_num_steps = pairwise_operands_[0].num_views_ + num_steps;

  // Set dimensions for each pairwise operand and accumulate their total stride.
  size_t new_stride = 0;
  for (size_t i = 0; i < pairwise_operands_.size(); ++i) {
    const size_t new_view_size = new_num_steps * pairwise_sizes_[i];
    pairwise_operands_[i].num_views_ = new_num_steps;
    pairwise_operands_[i].view_size_ = new_view_size;
    new_stride += PadToAlignment(new_view_size);
  }

  // Note that Reset() does not preserve the existing array and its contents.
  // Although preserving existing data would be nice, it is complex because
  // pairwise operands grow in both dimensions.  In addition, users should be
  // allocating pairwise operands in one shot for speed reasons, in which case
  // there is no existing data anyway.
  pairwise_array_.Reset(new_num_steps * new_stride);

  // Set the new base pointer and stride on each pairwise operand.
  char *data = pairwise_array_.view().data();
  for (MutableAlignedArea &operand : pairwise_operands_) {
    operand.data_ = data;
    operand.view_stride_ = new_stride;
    data += PadToAlignment(operand.view_size_);
  }
  DCHECK_EQ(data - pairwise_array_.view().data(), new_stride);
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/operands.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring and allocating operands. An operand is made up of
// aligned byte arrays, and can be used as an input, output, or intermediate
// value in some computation.
#ifndef DRAGNN_RUNTIME_OPERANDS_H_
#define DRAGNN_RUNTIME_OPERANDS_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Possible types of operands.
enum class OperandType {
  // A single byte array.  For example, an intermediate value that is computed
  // once per transition step.  Since it is not an output, the same storage
  // could be reused across all steps.
  kSingular,

  // A sequence of identically-sized byte arrays, one per transition step.  For
  // example, a layer containing one activation vector per step.
  kStepwise,

  // A grid with one byte array for each pair of transition steps, including
  // self pairings.  The byte arrays are grouped and concatenated in "rows",
  // forming one byte array per step.  For example, if there are N steps and D
  // bytes per pair, the operand would have N arrays of size N*D bytes.  In a
  // basic attention model with one "similarity" between pairs of steps, one
  // might use a pairwise operand with D=sizeof(float).  For best performance,
  // use Operands::AddSteps() to allocate all steps at once when working with
  // pairwise operands.
  kPairwise,
};
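// To make the kPairwise layout concrete (illustrative numbers, not from the
// original header): with N=3 steps and D=4 bytes per pair (e.g., one float
// "similarity"), the operand holds 3 views of 3*4 = 12 bytes each, where view
// i concatenates the D-byte entries pairing step i with steps 0, 1, and 2.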
// A specification of an operand.
struct OperandSpec {
  // Creates a trivial specification.
  OperandSpec() = default;

  // Creates a specification with the |type| and |size|.
  OperandSpec(OperandType type, size_t size) : type(type), size(size) {}

  // Type of the operand.
  OperandType type = OperandType::kSingular;

  // Size of each aligned byte array in the operand.
  size_t size = 0;
};
// An opaque handle to an operand.
class OperandHandle;

// A class that manages a set of operand specifications and associates each
// operand with a handle.  Operand contents can be retrieved using these
// handles; see Operands below.
class OperandManager {
 public:
  // Creates an empty manager.
  OperandManager() = default;

  // Adds an operand configured according to the |spec| and returns its handle.
  OperandHandle Add(const OperandSpec &spec);

  // Accessors.
  const OperandSpec &spec(OperandHandle handle) const;

 private:
  friend class Operands;

  // Specification of each operand.
  std::vector<OperandSpec> specs_;

  // Mapping from the handle index of an operand to its index amongst operands
  // of the same type.
  std::vector<size_t> handle_index_to_typed_index_;

  // Span of each singular operand, as a (start-offset,size) pair, relative to
  // the byte array containing all singular operands.
  std::vector<std::pair<size_t, size_t>> singular_spans_;

  // Span of each stepwise operand, as a (start-offset,size) pair, relative to
  // the byte array for each step.
  std::vector<std::pair<size_t, size_t>> stepwise_spans_;

  // Size of each pairwise operand.
  std::vector<size_t> pairwise_sizes_;

  // Number of bytes used by all singular operands, including alignment padding.
  size_t singular_size_ = 0;

  // Number of bytes used by all stepwise operands on each step, including
  // alignment padding.
  size_t stepwise_stride_ = 0;
};
// A set of operands.  The structure of the operands is configured by an
// OperandManager, and operand contents can be accessed using the handles
// produced by the manager.
//
// Multiple Operands instances can share the same OperandManager.  In addition,
// an Operands instance can be reused by repeatedly Reset()-ing it, potentially
// with different OperandManagers.  Such reuse can reduce allocation overhead.
class Operands {
 public:
  // Creates an empty set.
  Operands() = default;

  // Resets this to the operands defined by the |manager|.  The |manager| must
  // live until this is destroyed or Reset() again, and should not be modified
  // during that time.  Stepwise and pairwise operands start with 0 steps; use
  // AddStep() to extend them.  Pre-allocates stepwise operands so that they
  // will not be reallocated during the first |pre_allocate_num_steps| calls to
  // AddStep().  Invalidates all previously-returned operands.
  void Reset(const OperandManager *manager, size_t pre_allocate_num_steps);

  // Extends stepwise and pairwise operands by one or more steps.  Requires
  // that Reset() was called.  Invalidates any previously-returned views of
  // stepwise and pairwise operands.  Preserves data for pre-existing steps of
  // stepwise operands, but not for pre-existing pairwise operands.  In
  // general, pairwise operands should be allocated in one shot, not
  // incrementally.
  void AddStep() { AddSteps(1); }
  void AddSteps(size_t num_steps);

  // Returns the singular operand associated with the |handle|.  The returned
  // view is invalidated by Reset().
  MutableAlignedView GetSingular(OperandHandle handle) const;

  // Returns the stepwise operand associated with the |handle|.  The returned
  // area is invalidated by Reset() and AddStep().
  MutableAlignedArea GetStepwise(OperandHandle handle) const;

  // Returns the pairwise operand associated with the |handle|.  The returned
  // area is invalidated by Reset() and AddStep().
  MutableAlignedArea GetPairwise(OperandHandle handle) const;

 private:
  // Extends stepwise operands only; see AddSteps().
  void AddStepwiseSteps(size_t num_steps);

  // Extends pairwise operands only; see AddSteps().
  void AddPairwiseSteps(size_t num_steps);

  // Manager of the operands in this set.
  const OperandManager *manager_ = nullptr;

  // Cached members from the |manager_|.
  tensorflow::gtl::ArraySlice<size_t> handle_index_to_typed_index_;
  tensorflow::gtl::ArraySlice<std::pair<size_t, size_t>> stepwise_spans_;
  size_t stepwise_stride_ = 0;
  tensorflow::gtl::ArraySlice<size_t> pairwise_sizes_;

  // Byte arrays holding operands of each type.  Storage is separated because
  // each type grows differently with the number of steps.
  UniqueAlignedArray singular_array_;
  UniqueAlignedArray stepwise_array_;
  UniqueAlignedArray pairwise_array_;

  // Lists of operands of each type.
  std::vector<MutableAlignedView> singular_operands_;
  std::vector<MutableAlignedArea> stepwise_operands_;
  std::vector<MutableAlignedArea> pairwise_operands_;
};
// Implementation details below.

// An opaque handle to an operand.
class OperandHandle {
 public:
  // Creates an invalid handle.
  OperandHandle() = default;

 private:
  friend class OperandManager;
  friend class Operands;

  // Creates a handle that points to the |index|.
  explicit OperandHandle(size_t index) : index_(index) {}

  // Index of the operand in its manager.
  size_t index_ = SIZE_MAX;
};

inline const OperandSpec &OperandManager::spec(OperandHandle handle) const {
  return specs_[handle.index_];
}

inline MutableAlignedView Operands::GetSingular(OperandHandle handle) const {
  DCHECK(manager_->spec(handle).type == OperandType::kSingular)
      << "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
  return singular_operands_[handle_index_to_typed_index_[handle.index_]];
}

inline MutableAlignedArea Operands::GetStepwise(OperandHandle handle) const {
  DCHECK(manager_->spec(handle).type == OperandType::kStepwise)
      << "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
  return stepwise_operands_[handle_index_to_typed_index_[handle.index_]];
}

inline MutableAlignedArea Operands::GetPairwise(OperandHandle handle) const {
  DCHECK(manager_->spec(handle).type == OperandType::kPairwise)
      << "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
  return pairwise_operands_[handle_index_to_typed_index_[handle.index_]];
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_OPERANDS_H_
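Pieced together from the comments above, the intended lifecycle is roughly the following sketch (the 64-float dimension and the handle name are illustrative, not from the deleted header):

OperandManager manager;
const OperandHandle hidden =
    manager.Add({OperandType::kStepwise, 64 * sizeof(float)});

Operands operands;
operands.Reset(&manager, /*pre_allocate_num_steps=*/16);
for (size_t step = 0; step < 16; ++step) {
  operands.AddStep();  // stays within the pre-allocated steps; no reallocation
  MutableVector<float> row(operands.GetStepwise(hidden).view(step));
  // ... write this step's 64 activations into |row| ...
}

The tests in operands_test.cc below exercise exactly this pattern.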
research/syntaxnet/dragnn/runtime/operands_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
#include <string.h>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Expects that the two pointers are the same.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}

// Sets the |vector| to |size| copies of the |value|.
template <class T>
void Fill(MutableVector<T> vector, size_t size, T value) {
  ASSERT_EQ(vector.size(), size);
  for (T &element : vector) element = value;
}

// Expects that the |vector| contains |size| copies of the |expected_value|.
template <class T>
void ExpectFilled(Vector<T> vector, size_t size, T expected_value) {
  ASSERT_EQ(vector.size(), size);
  for (const T element : vector) EXPECT_EQ(element, expected_value);
}
// Tests that OperandManager can add operands and remember their configuration.
TEST(OperandManagerTest, Add) {
  OperandManager manager;
  const OperandHandle handle1 = manager.Add({OperandType::kSingular, 7});
  const OperandHandle handle2 = manager.Add({OperandType::kStepwise, 11});

  EXPECT_EQ(manager.spec(handle1).type, OperandType::kSingular);
  EXPECT_EQ(manager.spec(handle1).size, 7);
  EXPECT_EQ(manager.spec(handle2).type, OperandType::kStepwise);
  EXPECT_EQ(manager.spec(handle2).size, 11);
}
// Tests that Operands contains operands whose dimensions match its manager.
TEST(OperandsTest, Dimensions) {
  const size_t kDim1 = 3, kDim2 = 41, kDim3 = 19, kDim4 = 77;
  OperandManager manager;
  const OperandHandle handle1 =
      manager.Add({OperandType::kSingular, kDim1 * sizeof(float)});
  const OperandHandle handle2 =
      manager.Add({OperandType::kStepwise, kDim2 * sizeof(double)});
  const OperandHandle handle3 =
      manager.Add({OperandType::kSingular, kDim3 * sizeof(float)});
  const OperandHandle handle4 =
      manager.Add({OperandType::kStepwise, kDim4 * sizeof(int)});

  AlignedView view;
  AlignedArea area;
  Operands operands;
  operands.Reset(&manager, 10);

  view = operands.GetSingular(handle1);
  EXPECT_EQ(view.size(), kDim1 * sizeof(float));
  EXPECT_EQ(Vector<float>(view).size(), kDim1);

  area = operands.GetStepwise(handle2);
  EXPECT_EQ(area.num_views(), 0);  // no steps yet
  EXPECT_EQ(area.view_size(), kDim2 * sizeof(double));
  EXPECT_EQ(Matrix<double>(area).num_rows(), 0);  // starts with no steps
  EXPECT_EQ(Matrix<double>(area).num_columns(), kDim2);

  view = operands.GetSingular(handle3);
  EXPECT_EQ(view.size(), kDim3 * sizeof(float));
  EXPECT_EQ(Vector<float>(view).size(), kDim3);

  area = operands.GetStepwise(handle4);
  EXPECT_EQ(area.num_views(), 0);  // no steps yet
  EXPECT_EQ(area.view_size(), kDim4 * sizeof(int));
  EXPECT_EQ(Matrix<int>(area).num_rows(), 0);  // starts with no steps
  EXPECT_EQ(Matrix<int>(area).num_columns(), kDim4);
}
// Tests that Operands can incrementally extend stepwise operands while
// preserving existing values.
TEST(OperandsTest, AddStepToStepwise) {
  const size_t kDim1 = 23, kDim2 = 29;
  OperandManager manager;
  const OperandHandle handle1 =
      manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
  const OperandHandle handle2 =
      manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});

  Operands operands;
  operands.Reset(&manager, 10);

  // Repeatedly add a step and fill it with values.
  for (int i = 0; i < 100; ++i) {
    operands.AddStep();
    Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
         1000.0 + i);
    Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
         2000 + i);
  }

  // Check that data from earlier steps is preserved across reallocations.
  for (int i = 0; i < 100; ++i) {
    ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
                 1000.0 + i);
    ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
                 2000 + i);
  }
}
// Tests that Operands can add multiple steps at once.
TEST(OperandsTest, AddStepsToStepwise) {
  const size_t kDim1 = 23, kDim2 = 29;
  OperandManager manager;
  const OperandHandle handle1 =
      manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
  const OperandHandle handle2 =
      manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});

  Operands operands;
  operands.Reset(&manager, 10);

  // Repeatedly add blocks of steps and fill them with values.
  for (int i = 0; i < 100; ++i) {
    if (i % 10 == 0) operands.AddSteps(10);  // occasionally add a block
    Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
         1000.0 + i);
    Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
         2000 + i);
  }

  // Check that data from earlier steps is preserved across reallocations.
  for (int i = 0; i < 100; ++i) {
    ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
                 1000.0 + i);
    ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
                 2000 + i);
  }
}
// Tests that Operands can add multiple steps to a pairwise operand.
TEST(OperandsTest, AddStepsPairwise) {
  const size_t kDim1 = 4, kDim2 = 31;
  OperandManager manager;
  const OperandHandle handle1 = manager.Add({OperandType::kPairwise, kDim1});
  const OperandHandle handle2 = manager.Add({OperandType::kPairwise, kDim2});

  Operands operands;
  operands.Reset(&manager, 10);

  {  // A 1x1 pairwise operand.
    operands.AddSteps(1);
    const MutableAlignedArea area1 = operands.GetPairwise(handle1);
    const MutableAlignedArea area2 = operands.GetPairwise(handle2);
    EXPECT_EQ(area1.num_views(), 1);
    EXPECT_EQ(area2.num_views(), 1);
    EXPECT_EQ(area1.view_size(), kDim1);
    EXPECT_EQ(area2.view_size(), kDim2);

    // Write to operands to test the validity of the underlying memory region.
    memset(area1.view(0).data(), 0, kDim1);
    memset(area2.view(0).data(), 0, kDim2);
  }

  {  // A 10x10 pairwise operand.
    operands.AddSteps(9);
    const MutableAlignedArea area1 = operands.GetPairwise(handle1);
    const MutableAlignedArea area2 = operands.GetPairwise(handle2);
    EXPECT_EQ(area1.num_views(), 10);
    EXPECT_EQ(area2.num_views(), 10);
    EXPECT_EQ(area1.view_size(), 10 * kDim1);
    EXPECT_EQ(area2.view_size(), 10 * kDim2);

    // Infer the stride by comparing pointers between consecutive views.
    const size_t expected_stride =
        PadToAlignment(10 * kDim1) + PadToAlignment(10 * kDim2);
    EXPECT_EQ(area1.view(1).data() - area1.view(0).data(), expected_stride);
    EXPECT_EQ(area2.view(1).data() - area2.view(0).data(), expected_stride);

    // Write to operands to test the validity of the underlying memory region.
    memset(area1.view(9).data(), 0, 10 * kDim1);
    memset(area2.view(9).data(), 0, 10 * kDim2);
  }
}
// Tests that Operands can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST(OperandsTest, ResetWithDifferentManagers) {
  std::vector<OperandManager> managers;
  std::vector<std::tuple<OperandHandle, OperandHandle, OperandHandle>> handles;
  for (int dim = 0; dim < 10; ++dim) {
    managers.emplace_back();
    handles.emplace_back(
        managers.back().Add({OperandType::kSingular, dim * sizeof(double)}),
        managers.back().Add({OperandType::kStepwise, dim * sizeof(int)}),
        managers.back().Add({OperandType::kPairwise, dim * sizeof(float)}));
  }

  Operands operands;
  for (int trial = 0; trial < 10; ++trial) {
    for (int dim = 0; dim < 10; ++dim) {
      operands.Reset(&managers[dim], 10);
      const OperandHandle singular_handle = std::get<0>(handles[dim]);
      const OperandHandle stepwise_handle = std::get<1>(handles[dim]);
      const OperandHandle pairwise_handle = std::get<2>(handles[dim]);

      // Fill the singular operand.
      Fill(MutableVector<double>(operands.GetSingular(singular_handle)), dim,
           100.0 * trial + dim);

      // Check the singular operands.
      ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), dim,
                   100.0 * trial + dim);

      // Repeatedly add a step and fill it with values.
      for (int step = 0; step < 100; ++step) {
        operands.AddStep();
        Fill(MutableVector<int>(
                 operands.GetStepwise(stepwise_handle).view(step)),
             dim, 1000 * trial + 100 * dim + step);
      }

      // Check that data from earlier steps is preserved across reallocations.
      for (int step = 0; step < 100; ++step) {
        ExpectFilled(
            Vector<int>(operands.GetStepwise(stepwise_handle).view(step)),
            dim, 1000 * trial + 100 * dim + step);
      }

      // Check the dimensions of pairwise operands.
      Matrix<float> pairwise(operands.GetPairwise(pairwise_handle));
      EXPECT_EQ(pairwise.num_rows(), 100);
      EXPECT_EQ(pairwise.num_columns(), 100 * dim);
    }
  }
}
// Tests that one OperandManager can be shared simultaneously between multiple
// Operands instances.
TEST(OperandsTest, SharedManager) {
  const size_t kDim = 17;
  OperandManager manager;
  const OperandHandle singular_handle =
      manager.Add({OperandType::kSingular, kDim * sizeof(double)});
  const OperandHandle stepwise_handle =
      manager.Add({OperandType::kStepwise, kDim * sizeof(int)});

  std::vector<Operands> operands_vec(10);
  for (Operands &operands : operands_vec) operands.Reset(&manager, 10);

  // Fill all singular operands.
  for (int trial = 0; trial < operands_vec.size(); ++trial) {
    const Operands &operands = operands_vec[trial];
    Fill(MutableVector<double>(operands.GetSingular(singular_handle)), kDim,
         3.0 * trial);
  }

  // Check all singular operands.
  for (int trial = 0; trial < operands_vec.size(); ++trial) {
    const Operands &operands = operands_vec[trial];
    ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), kDim,
                 3.0 * trial);
  }

  // Fill all stepwise operands.  Interleave operations on the operands on each
  // step, so all operands are "active" at the same time.
  for (int step = 0; step < 100; ++step) {
    for (int trial = 0; trial < 10; ++trial) {
      Operands &operands = operands_vec[trial];
      operands.AddStep();
      Fill(MutableVector<int>(
               operands.GetStepwise(stepwise_handle).view(step)),
           kDim, trial * 999 + step);
    }
  }

  // Check all stepwise operands.
  for (int step = 0; step < 100; ++step) {
    for (int trial = 0; trial < 10; ++trial) {
      const Operands &operands = operands_vec[trial];
      ExpectFilled(
          Vector<int>(operands.GetStepwise(stepwise_handle).view(step)),
          kDim, trial * 999 + step);
    }
  }
}
// Tests that an Operands uses all of the pre-allocated steps and reallocates
// exactly when it exhausts the pre-allocated array.
TEST(OperandsTest, UsesPreAllocatedSteps) {
  const size_t kBytes = 5;
  const size_t kPreAllocateNumSteps = 10;
  OperandManager manager;
  const OperandHandle handle = manager.Add({OperandType::kStepwise, kBytes});

  Operands operands;
  operands.Reset(&manager, kPreAllocateNumSteps);

  // The first N steps fit exactly in the pre-allocated array.  Access the base
  // of the stepwise array via the first view.
  operands.AddStep();
  char *const pre_allocated_data = operands.GetStepwise(handle).view(0).data();
  for (size_t step = 1; step < kPreAllocateNumSteps; ++step) {
    operands.AddStep();
    ASSERT_EQ(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
  }

  // The N+1'st step triggers a reallocation, which is guaranteed to yield a new
  // pointer because it creates a separate array and copies into it.
  operands.AddStep();
  ASSERT_NE(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Links to the previous step in the same component.  Templated on a bool that
// indicates the direction that the transition system runs in.
template <bool left_to_right>
class RecurrentSequenceLinker : public SequenceLinker {
 public:
  // Implements SequenceLinker.
  bool Supports(const LinkedFeatureChannel &channel,
                const ComponentSpec &component_spec) const override;
  tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
                                const ComponentSpec &component_spec) override;
  tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
                              std::vector<int32> *links) const override;
};

template <bool left_to_right>
bool RecurrentSequenceLinker<left_to_right>::Supports(
    const LinkedFeatureChannel &channel,
    const ComponentSpec &component_spec) const {
  TransitionSystemTraits traits(component_spec);

  // Here, fml="bias" and source_translator="history" are a DRAGNN recipe for
  // linking to the previous transition step.  More concretely,
  //   * "bias" always extracts index 0.
  //   * "history" subtracts the index it is given from (#steps - 1).
  // Putting the two together, we link to (#steps - 1 - 0); i.e., the previous
  // transition step.
  return (channel.fml() == "bias" || channel.fml() == "bias(0)") &&
         channel.source_component() == component_spec.name() &&
         channel.source_translator() == "history" &&
         traits.is_left_to_right == left_to_right && traits.is_sequential;
}
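
// A worked instance of the recipe above: if the component has taken 5 steps,
// "bias" extracts 0 and "history" maps it to 5 - 1 - 0 = 4, the index of the
// previous transition step.  A spec fragment that satisfies Supports() reads,
// roughly:
//
//   linked_feature {
//     fml: "bias"                          # or "bias(0)"
//     source_component: <this component>   # i.e., a recurrent link
//     source_translator: "history"
//   }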

template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::Initialize(
    const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
  return tensorflow::Status::OK();
}

template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::GetLinks(
    size_t source_num_steps, InputBatchCache *input,
    std::vector<int32> *links) const {
  links->resize(source_num_steps);
  if (left_to_right) {
    int32 index = -1;
    for (int32 &link : *links) link = index++;
  } else {
    int32 index = static_cast<int32>(source_num_steps) - 1;
    for (int32 &link : *links) link = --index;
  }
  return tensorflow::Status::OK();
}
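
// Worked example: with source_num_steps = 4, the left-to-right linker yields
// {-1, 0, 1, 2} and the right-to-left linker yields {2, 1, 0, -1}.  The -1
// marks the position processed first, which has no previous step to link to.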

using LeftToRightRecurrentSequenceLinker = RecurrentSequenceLinker<true>;
using RightToLeftRecurrentSequenceLinker = RecurrentSequenceLinker<false>;

DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LeftToRightRecurrentSequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(RightToLeftRecurrentSequenceLinker);

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

research/syntaxnet/dragnn/runtime/recurrent_sequence_linkers_test.cc deleted 100644 → 0

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
  ComponentSpec component_spec;
  component_spec.set_name("test_component");
  component_spec.mutable_transition_system()->set_registered_name(
      "shift-only");
  LinkedFeatureChannel *channel = component_spec.add_linked_feature();
  channel->set_fml("bias");
  channel->set_source_component("test_component");
  channel->set_source_translator("history");
  return component_spec;
}
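
// In textproto form, the supported spec built above reads, roughly:
//
//   name: "test_component"
//   transition_system { registered_name: "shift-only" }
//   linked_feature {
//     fml: "bias"
//     source_component: "test_component"
//     source_translator: "history"
//   }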

// Tests that the linker supports appropriate specs.
TEST(RecurrentSequenceLinkerTest, Supported) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");

  (*component_spec.mutable_transition_system()
        ->mutable_parameters())["left_to_right"] = "false";
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");

  channel.set_fml("bias(0)");
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");

  (*component_spec.mutable_transition_system()
        ->mutable_parameters())["left_to_right"] = "true";
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");
}

// Tests that the linker requires the right transition system.
TEST(RecurrentSequenceLinkerTest, WrongTransitionSystem) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  component_spec.mutable_transition_system()->set_registered_name("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires the right FML.
TEST(RecurrentSequenceLinkerTest, WrongFml) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_fml("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires a recurrent link.
TEST(RecurrentSequenceLinkerTest, WrongSource) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_source_component("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires the right translator.
TEST(RecurrentSequenceLinkerTest, WrongTranslator) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_source_translator("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker can be initialized and used to extract links.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinks) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  std::unique_ptr<SequenceLinker> linker;
  TF_ASSERT_OK(SequenceLinker::New("LeftToRightRecurrentSequenceLinker",
                                   channel, component_spec, &linker));

  InputBatchCache input;
  std::vector<int32> links = {123, 456, 789};  // gets overwritten
  TF_ASSERT_OK(linker->GetLinks(10, &input, &links));

  const std::vector<int32> expected_links = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
  EXPECT_EQ(links, expected_links);
}

// Tests that the links are reversed for right-to-left components.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinksRightToLeft) {
  ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  std::unique_ptr<SequenceLinker> linker;
  TF_ASSERT_OK(SequenceLinker::New("RightToLeftRecurrentSequenceLinker",
                                   channel, component_spec, &linker));

  InputBatchCache input;
  std::vector<int32> links = {123, 456, 789};  // gets overwritten
  TF_ASSERT_OK(linker->GetLinks(10, &input, &links));

  const std::vector<int32> expected_links = {8, 7, 6, 5, 4, 3, 2, 1, 0, -1};
  EXPECT_EQ(links, expected_links);
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

research/syntaxnet/dragnn/runtime/reversed_sequence_linker.cc deleted 100644 → 0

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Applies a reversed identity function.
class ReversedSequenceLinker : public SequenceLinker {
 public:
  // Implements SequenceLinker.
  bool Supports(const LinkedFeatureChannel &channel,
                const ComponentSpec &component_spec) const override;
  tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
                                const ComponentSpec &component_spec) override;
  tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
                              std::vector<int32> *links) const override;
};

bool ReversedSequenceLinker::Supports(
    const LinkedFeatureChannel &channel,
    const ComponentSpec &component_spec) const {
  TransitionSystemTraits traits(component_spec);

  // Note: Add more "||" clauses as needed.
  return ((channel.fml() == "input.focus" &&
           channel.source_translator() == "reverse-token") ||
          (channel.fml() == "char-input.focus" &&
           channel.source_translator() == "reverse-char")) &&
         traits.is_sequential;
}

tensorflow::Status ReversedSequenceLinker::Initialize(
    const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
  return tensorflow::Status::OK();
}

tensorflow::Status ReversedSequenceLinker::GetLinks(
    size_t source_num_steps, InputBatchCache *input,
    std::vector<int32> *links) const {
  links->resize(source_num_steps);
  int32 index = links->size();
  for (int32 &link : *links) link = --index;
  return tensorflow::Status::OK();
}
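
// Worked example: with source_num_steps = 4 this yields {3, 2, 1, 0}; i.e.,
// link[i] = source_num_steps - 1 - i, the "reversed identity" noted in the
// class comment.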

DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(ReversedSequenceLinker);

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

research/syntaxnet/dragnn/runtime/reversed_sequence_linker_test.cc deleted 100644 → 0

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
  ComponentSpec component_spec;
  component_spec.mutable_transition_system()->set_registered_name(
      "shift-only");
  LinkedFeatureChannel *channel = component_spec.add_linked_feature();
  channel->set_fml("input.focus");
  channel->set_source_translator("reverse-token");
  return component_spec;
}

// Tests that the linker supports appropriate specs.
TEST(ReversedSequenceLinkerTest, Supported) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "ReversedSequenceLinker");

  channel.set_fml("char-input.focus");
  channel.set_source_translator("reverse-char");
  TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "ReversedSequenceLinker");
}

// Tests that the linker requires the right transition system.
TEST(ReversedSequenceLinkerTest, WrongTransitionSystem) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  component_spec.mutable_transition_system()->set_registered_name("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires the right FML.
TEST(ReversedSequenceLinkerTest, WrongFml) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_fml("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires the right translator.
TEST(ReversedSequenceLinkerTest, WrongTranslator) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_source_translator("bad");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker requires the right combination of FML and translator.
TEST(ReversedSequenceLinkerTest, MismatchedFmlAndTranslator) {
  string name;
  ComponentSpec component_spec = MakeSupportedSpec();
  LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
  channel.set_fml("input.focus");
  channel.set_source_translator("reverse-char");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));

  channel.set_fml("char-input.focus");
  channel.set_source_translator("reverse-token");
  EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}

// Tests that the linker can be initialized and used to extract links.
TEST(ReversedSequenceLinkerTest, InitializeAndGetLinks) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  std::unique_ptr<SequenceLinker> linker;
  TF_ASSERT_OK(SequenceLinker::New("ReversedSequenceLinker", channel,
                                   component_spec, &linker));

  InputBatchCache input;
  std::vector<int32> links = {123, 456, 789};  // gets overwritten
  TF_ASSERT_OK(linker->GetLinks(10, &input, &links));

  const std::vector<int32> expected_links = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  EXPECT_EQ(links, expected_links);
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

research/syntaxnet/dragnn/runtime/select_best_component_transformer.cc deleted 100644 → 0

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Transformer that selects the best component subclass for the ComponentSpec.
class SelectBestComponentTransformer : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.
  tensorflow::Status Transform(const string &component_type,
                               ComponentSpec *component_spec) override {
    string best_component_type;
    TF_RETURN_IF_ERROR(
        Component::Select(*component_spec, &best_component_type));
    component_spec->mutable_component_builder()->set_registered_name(
        best_component_type);
    if (component_type != best_component_type) {
      LOG(INFO) << "Component '" << component_spec->name()
                << "' builder updated from " << component_type << " to "
                << best_component_type << ".";
    } else {
      VLOG(2) << "Component '" << component_spec->name() << "' builder type "
              << component_type << " unchanged.";
    }
    return tensorflow::Status::OK();
  }
};
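
// A minimal sketch of the transformer's effect, assuming Component::Select()
// picks, say, "MyelinDynamicComponent" for a spec whose builder was registered
// as "DynamicComponent" (both names illustrative):
//
//   before:  component_builder { registered_name: "DynamicComponent" }
//   after:   component_builder { registered_name: "MyelinDynamicComponent" }
//
// and an INFO log records the update.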

DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(SelectBestComponentTransformer);

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet