Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3572 deletions
+0
-3572
research/syntaxnet/dragnn/runtime/select_best_component_transformer_test.cc
.../dragnn/runtime/select_best_component_transformer_test.cc
+0
-118
research/syntaxnet/dragnn/runtime/sequence_backend.cc
research/syntaxnet/dragnn/runtime/sequence_backend.cc
+0
-152
research/syntaxnet/dragnn/runtime/sequence_backend.h
research/syntaxnet/dragnn/runtime/sequence_backend.h
+0
-124
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
+0
-172
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
...ntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
+0
-195
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
...et/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
+0
-311
research/syntaxnet/dragnn/runtime/sequence_component_transformer.cc
...yntaxnet/dragnn/runtime/sequence_component_transformer.cc
+0
-144
research/syntaxnet/dragnn/runtime/sequence_component_transformer_test.cc
...net/dragnn/runtime/sequence_component_transformer_test.cc
+0
-261
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
+0
-75
research/syntaxnet/dragnn/runtime/sequence_extractor.h
research/syntaxnet/dragnn/runtime/sequence_extractor.h
+0
-100
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
+0
-166
research/syntaxnet/dragnn/runtime/sequence_features.cc
research/syntaxnet/dragnn/runtime/sequence_features.cc
+0
-104
research/syntaxnet/dragnn/runtime/sequence_features.h
research/syntaxnet/dragnn/runtime/sequence_features.h
+0
-159
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
+0
-346
research/syntaxnet/dragnn/runtime/sequence_linker.cc
research/syntaxnet/dragnn/runtime/sequence_linker.cc
+0
-74
research/syntaxnet/dragnn/runtime/sequence_linker.h
research/syntaxnet/dragnn/runtime/sequence_linker.h
+0
-105
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
+0
-167
research/syntaxnet/dragnn/runtime/sequence_links.cc
research/syntaxnet/dragnn/runtime/sequence_links.cc
+0
-146
research/syntaxnet/dragnn/runtime/sequence_links.h
research/syntaxnet/dragnn/runtime/sequence_links.h
+0
-169
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
+0
-484
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/select_best_component_transformer_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/extensions.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Base class for test components.
class
TestComponentBase
:
public
Component
{
public:
// Partially implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
,
VariableStore
*
,
NetworkStateManager
*
,
ExtensionManager
*
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
,
ComputeSession
*
,
ComponentTrace
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
PreferredTo
(
const
Component
&
)
const
override
{
return
false
;
}
};
// Supports components whose builder name includes "Foo".
class
ContainsFoo
:
public
TestComponentBase
{
public:
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
.
find
(
"Foo"
)
!=
string
::
npos
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ContainsFoo
);
// Supports components whose builder name includes "Bar".
class
ContainsBar
:
public
TestComponentBase
{
public:
// Implements Component.
bool
Supports
(
const
ComponentSpec
&
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
.
find
(
"Bar"
)
!=
string
::
npos
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ContainsBar
);
// Tests that a spec with an unknown builder name causes an error.
TEST
(
SelectBestComponentTransformerTest
,
Unknown
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"unknown"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Could not find a best"
));
}
// Tests that a spec with builder "Foo" is changed to "ContainsFoo".
TEST
(
SelectBestComponentTransformerTest
,
ChangeToContainsFoo
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"Foo"
);
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ContainsFoo"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that a spec with builder "Bar" is changed to "ContainsBar".
TEST
(
SelectBestComponentTransformerTest
,
ChangeToContainsBar
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"Bar"
);
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ContainsBar"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that a spec with builder "FooBar" causes a conflict.
TEST
(
SelectBestComponentTransformerTest
,
Conflict
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"FooBar"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"both think they should be dis-preferred"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_backend.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/core/component_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
std
::
function
<
int
(
int
,
int
,
int
)
>
SequenceBackend
::
GetStepLookupFunction
(
const
string
&
method
)
{
if
(
method
==
"reverse-char"
||
method
==
"reverse-token"
)
{
// Reverses the |index| in the sequence. We are agnostic to whether the
// input is a sequence of tokens or chars.
return
[
this
](
int
unused_batch_index
,
int
unused_beam_index
,
int
index
)
{
index
=
sequence_size_
-
index
-
1
;
return
index
>=
0
&&
index
<
sequence_size_
?
index
:
-
1
;
};
}
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Unknown step lookup function: "
<<
method
;
}
void
SequenceBackend
::
InitializeComponent
(
const
ComponentSpec
&
spec
)
{
name_
=
spec
.
name
();
}
void
SequenceBackend
::
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
{
// Store the |parent_states| for forwarding to downstream components.
parent_states_
=
parent_states
;
}
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
SequenceBackend
::
GetBeam
()
{
// Forward the states of the previous component.
return
parent_states_
;
}
int
SequenceBackend
::
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
{
// Forward the |current_index| to the previous component.
return
current_index
;
}
int
SequenceBackend
::
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
{
// Always return 0 since there is only one beam.
return
0
;
}
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
SequenceBackend
::
GetTraceProtos
()
const
{
// Return a single trace, since the beam and batch sizes are fixed at 1.
return
{{
ComponentTrace
()}};
}
string
SequenceBackend
::
Name
()
const
{
return
name_
;
}
int
SequenceBackend
::
BeamSize
()
const
{
return
1
;
}
int
SequenceBackend
::
BatchSize
()
const
{
return
1
;
}
bool
SequenceBackend
::
IsReady
()
const
{
return
true
;
}
bool
SequenceBackend
::
IsTerminal
()
const
{
return
true
;
}
void
SequenceBackend
::
FinalizeData
()
{}
void
SequenceBackend
::
ResetComponent
()
{}
void
SequenceBackend
::
InitializeTracing
()
{}
void
SequenceBackend
::
DisableTracing
()
{}
int
SequenceBackend
::
StepsTaken
(
int
batch_index
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
bool
SequenceBackend
::
AdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
num_items
,
int
num_actions
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
AdvanceFromOracle
()
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
SequenceBackend
::
GetOracleLabels
()
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int
*
offset_array_output
,
int
offset_array_size
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
int
SequenceBackend
::
BulkDenseFeatureSize
()
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
std
::
vector
<
LinkFeatures
>
SequenceBackend
::
GetRawLinkFeatures
(
int
channel_id
)
const
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
void
SequenceBackend
::
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
{
LOG
(
FATAL
)
<<
"["
<<
name_
<<
"] Not supported"
;
}
REGISTER_DRAGNN_COMPONENT
(
SequenceBackend
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_backend.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#define DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#include <functional>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Runtime-only component backend for sequence-based models. This is not used
// at training time, and provides trivial implementations of most methods. This
// is intended to be used with alternative feature extraction approaches, such
// as SequenceExtractor.
class
SequenceBackend
:
public
dragnn
::
Component
{
public:
// Sets the size of the sequence in the current input.
void
SetSequenceSize
(
int
size
)
{
sequence_size_
=
size
;
}
// Implements dragnn::Component.
std
::
function
<
int
(
int
,
int
,
int
)
>
GetStepLookupFunction
(
const
string
&
method
)
override
;
void
InitializeComponent
(
const
ComponentSpec
&
spec
)
override
;
void
InitializeData
(
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
&
parent_states
,
int
max_beam_size
,
InputBatchCache
*
input_data
)
override
;
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
GetBeam
()
override
;
int
GetSourceBeamIndex
(
int
current_index
,
int
batch
)
const
override
;
int
GetBeamIndexAtStep
(
int
step
,
int
current_index
,
int
batch
)
const
override
;
std
::
vector
<
std
::
vector
<
ComponentTrace
>>
GetTraceProtos
()
const
override
;
string
Name
()
const
override
;
int
BeamSize
()
const
override
;
int
BatchSize
()
const
override
;
bool
IsReady
()
const
override
;
bool
IsTerminal
()
const
override
;
void
FinalizeData
()
override
;
void
ResetComponent
()
override
;
void
InitializeTracing
()
override
;
void
DisableTracing
()
override
;
// Not implemented, crashes when called.
int
StepsTaken
(
int
batch_index
)
const
override
;
// Not implemented, crashes when called.
bool
AdvanceFromPrediction
(
const
float
*
transition_matrix
,
int
num_items
,
int
num_actions
)
override
;
// Not implemented, crashes when called.
void
AdvanceFromOracle
()
override
;
// Not implemented, crashes when called.
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
GetOracleLabels
()
const
override
;
// Not implemented, crashes when called.
int
GetFixedFeatures
(
std
::
function
<
int32
*
(
int
)
>
allocate_indices
,
std
::
function
<
int64
*
(
int
)
>
allocate_ids
,
std
::
function
<
float
*
(
int
)
>
allocate_weights
,
int
channel_id
)
const
override
;
// Not implemented, crashes when called.
int
BulkGetFixedFeatures
(
const
BulkFeatureExtractor
&
extractor
)
override
;
// Not implemented, crashes when called.
void
BulkEmbedFixedFeatures
(
int
batch_size_padding
,
int
num_steps_padding
,
int
output_array_size
,
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
)
override
;
// Not implemented, crashes when called.
void
BulkEmbedDenseFixedFeatures
(
const
vector
<
const
float
*>
&
per_channel_embeddings
,
float
*
embedding_output
,
int
embedding_output_size
,
int
*
offset_array_output
,
int
offset_array_size
)
override
;
// Not implemented, crashes when called.
int
BulkDenseFeatureSize
()
const
override
;
// Not implemented, crashes when called.
std
::
vector
<
LinkFeatures
>
GetRawLinkFeatures
(
int
channel_id
)
const
override
;
// Not implemented, crashes when called.
void
AddTranslatedLinkFeaturesToTrace
(
const
std
::
vector
<
LinkFeatures
>
&
features
,
int
channel_id
)
override
;
private:
// Name of the component that this backend supports.
string
name_
;
// Size of the current input sequence.
int
sequence_size_
=
0
;
// Parent states passed to InitializeData(), and passed along in GetBeam().
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
research/syntaxnet/dragnn/runtime/sequence_backend_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return -1 if the sequence size was never set.
TEST
(
SequenceBackendTest
,
ReverseCharUninitialized
)
{
for
(
const
string
&
reverse_method
:
{
"reverse-char"
,
"reverse-token"
})
{
SequenceBackend
backend
;
const
std
::
function
<
int
(
int
,
int
,
int
)
>
reverse
=
backend
.
GetStepLookupFunction
(
reverse_method
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
-
1
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
-
1
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
}
}
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return the reverse of the step index w.r.t. the most recent call
// to SetSequenceSize().
TEST
(
SequenceBackendTest
,
ReverseCharAfterSetSequenceSize
)
{
for
(
const
string
&
reverse_method
:
{
"reverse-char"
,
"reverse-token"
})
{
SequenceBackend
backend
;
const
std
::
function
<
int
(
int
,
int
,
int
)
>
reverse
=
backend
.
GetStepLookupFunction
(
reverse_method
);
backend
.
SetSequenceSize
(
10
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
9
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
8
);
EXPECT_EQ
(
reverse
(
8
,
8
,
8
),
1
);
EXPECT_EQ
(
reverse
(
9
,
9
,
9
),
0
);
EXPECT_EQ
(
reverse
(
10
,
10
,
10
),
-
1
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
5
),
4
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
backend
.
SetSequenceSize
(
11
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
1
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
0
),
10
);
EXPECT_EQ
(
reverse
(
1
,
1
,
1
),
9
);
EXPECT_EQ
(
reverse
(
8
,
8
,
8
),
2
);
EXPECT_EQ
(
reverse
(
9
,
9
,
9
),
1
);
EXPECT_EQ
(
reverse
(
10
,
10
,
10
),
0
);
EXPECT_EQ
(
reverse
(
-
1
,
-
1
,
5
),
5
);
EXPECT_EQ
(
reverse
(
0
,
0
,
9999
),
-
1
);
EXPECT_EQ
(
reverse
(
0
,
0
,
-
9999
),
-
1
);
}
}
// Tests that the input beam is forwarded.
TEST
(
SequenceBackendTest
,
BeamForwarding
)
{
SequenceBackend
backend
;
const
TransitionState
*
parent_state
=
nullptr
;
parent_state
+=
1234
;
// arbitrary non-null pointer
const
std
::
vector
<
std
::
vector
<
const
TransitionState
*>>
parent_states
=
{
{
parent_state
}};
const
int
ignored_max_beam_size
=
999
;
InputBatchCache
*
ignored_input
=
nullptr
;
backend
.
InitializeData
(
parent_states
,
ignored_max_beam_size
,
ignored_input
);
EXPECT_EQ
(
backend
.
GetBeam
(),
parent_states
);
}
// Tests the accessors of the backend.
TEST
(
SequenceBackendTest
,
Accessors
)
{
SequenceBackend
backend
;
ComponentSpec
spec
;
spec
.
set_name
(
"foo"
);
backend
.
InitializeComponent
(
spec
);
EXPECT_EQ
(
backend
.
Name
(),
"foo"
);
EXPECT_EQ
(
backend
.
BeamSize
(),
1
);
EXPECT_EQ
(
backend
.
BatchSize
(),
1
);
EXPECT_TRUE
(
backend
.
IsReady
());
EXPECT_TRUE
(
backend
.
IsTerminal
());
}
// Tests the trivial mutators of the backend.
TEST
(
SequenceBackendTest
,
Mutators
)
{
SequenceBackend
backend
;
// These are NOPs and should not crash.
backend
.
FinalizeData
();
backend
.
ResetComponent
();
backend
.
InitializeTracing
();
backend
.
DisableTracing
();
}
// Tests the beam index accessors of the backend.
TEST
(
SequenceBackendTest
,
BeamIndex
)
{
SequenceBackend
backend
;
// This always returns the current_index (first arg).
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
0
,
0
),
0
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
1
,
2
),
1
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
-
1
,
-
1
),
-
1
);
EXPECT_EQ
(
backend
.
GetSourceBeamIndex
(
10
,
99
),
10
);
// This always returns 0.
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
0
,
0
,
0
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
1
,
2
,
3
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
-
1
,
-
1
,
-
1
),
0
);
EXPECT_EQ
(
backend
.
GetBeamIndexAtStep
(
123
,
456
,
789
),
0
);
}
// Tests the that the backend produces a single empty trace.
TEST
(
SequenceBackendTest
,
Tracing
)
{
SequenceBackend
backend
;
const
ComponentTrace
empty_trace
;
const
auto
actual_traces
=
backend
.
GetTraceProtos
();
ASSERT_EQ
(
actual_traces
.
size
(),
1
);
ASSERT_EQ
(
actual_traces
[
0
].
size
(),
1
);
EXPECT_THAT
(
actual_traces
[
0
][
0
],
test
::
EqualsProto
(
empty_trace
));
}
// Tests the unsupported methods of the backend.
TEST
(
SequenceBackendTest
,
UnsupportedMethods
)
{
SequenceBackend
backend
;
EXPECT_DEATH
(
backend
.
StepsTaken
(
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AdvanceFromPrediction
(
nullptr
,
0
,
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AdvanceFromOracle
(),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetOracleLabels
(),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetFixedFeatures
(
nullptr
,
nullptr
,
nullptr
,
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
BulkGetFixedFeatures
(
BulkFeatureExtractor
(
nullptr
,
nullptr
,
nullptr
)),
"Not supported"
);
EXPECT_DEATH
(
backend
.
BulkEmbedFixedFeatures
(
0
,
0
,
0
,
{},
nullptr
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
GetRawLinkFeatures
(
0
),
"Not supported"
);
EXPECT_DEATH
(
backend
.
AddTranslatedLinkFeaturesToTrace
({},
0
),
"Not supported"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string.h>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Sequence-based bulk version of DynamicComponent.
class
SequenceBulkDynamicComponent
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
;
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
private:
// Evaluates all input features in the |state|, concatenates them into a
// matrix of inputs in the |network_states|, and returns the matrix.
Matrix
<
float
>
EvaluateInputs
(
const
SequenceModel
::
EvaluateState
&
state
,
const
NetworkStates
&
network_states
)
const
;
// Managers for input embeddings.
FixedEmbeddingManager
fixed_embedding_manager_
;
LinkedEmbeddingManager
linked_embedding_manager_
;
// Sequence-based model evaluator.
SequenceModel
sequence_model_
;
// Network unit for bulk inference.
std
::
unique_ptr
<
BulkNetworkUnit
>
bulk_network_unit_
;
// Concatenated input matrix.
LocalMatrixHandle
<
float
>
inputs_handle_
;
// Intermediate values used by sequence models.
SharedExtensionHandle
<
SequenceModel
::
EvaluateState
>
evaluate_state_handle_
;
};
bool
SequenceBulkDynamicComponent
::
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
{
// Require embedded fixed features.
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
embedding_dim
()
<
0
)
return
false
;
}
// Require non-transformed and non-recurrent linked features.
// TODO(googleuser): Make SequenceLinks support transformed linked features?
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
embedding_dim
()
>=
0
)
return
false
;
if
(
channel
.
source_component
()
==
component_spec
.
name
())
return
false
;
}
return
normalized_builder_name
==
"SequenceBulkDynamicComponent"
&&
SequenceModel
::
Supports
(
component_spec
);
}
// Returns the sum of the dimensions of all channels in the |manager|.
template
<
class
EmbeddingManager
>
size_t
SumEmbeddingDimensions
(
const
EmbeddingManager
&
manager
)
{
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
manager
.
num_channels
();
++
i
)
{
sum
+=
manager
.
embedding_dim
(
i
);
}
return
sum
;
}
tensorflow
::
Status
SequenceBulkDynamicComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
TF_RETURN_IF_ERROR
(
BulkNetworkUnit
::
CreateOrError
(
BulkNetworkUnit
::
GetClassName
(
component_spec
),
&
bulk_network_unit_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
variable_store
,
network_state_manager
));
const
size_t
concatenated_input_dim
=
SumEmbeddingDimensions
(
fixed_embedding_manager_
)
+
SumEmbeddingDimensions
(
linked_embedding_manager_
);
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
ValidateInputDimension
(
concatenated_input_dim
));
TF_RETURN_IF_ERROR
(
network_state_manager
->
AddLocal
(
concatenated_input_dim
,
&
inputs_handle_
));
TF_RETURN_IF_ERROR
(
sequence_model_
.
Initialize
(
component_spec
,
bulk_network_unit_
->
GetLogitsName
(),
&
fixed_embedding_manager_
,
&
linked_embedding_manager_
,
network_state_manager
));
extension_manager
->
GetShared
(
&
evaluate_state_handle_
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceBulkDynamicComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
const
NetworkStates
&
network_states
=
session_state
->
network_states
;
SequenceModel
::
EvaluateState
&
state
=
session_state
->
extensions
.
Get
(
evaluate_state_handle_
);
TF_RETURN_IF_ERROR
(
sequence_model_
.
Preprocess
(
session_state
,
compute_session
,
&
state
));
const
Matrix
<
float
>
inputs
=
EvaluateInputs
(
state
,
network_states
);
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Evaluate
(
inputs
,
session_state
));
return
sequence_model_
.
Predict
(
network_states
,
&
state
);
}
Matrix
<
float
>
SequenceBulkDynamicComponent
::
EvaluateInputs
(
const
SequenceModel
::
EvaluateState
&
state
,
const
NetworkStates
&
network_states
)
const
{
const
MutableMatrix
<
float
>
inputs
=
network_states
.
GetLocal
(
inputs_handle_
);
// Declared here for reuse in the loop below.
bool
is_out_of_bounds
=
false
;
Vector
<
float
>
embedding
;
// Handle forward and reverse iteration via a start index and increment.
int
target_index
=
sequence_model_
.
left_to_right
()
?
0
:
state
.
num_steps
-
1
;
const
int
target_increment
=
sequence_model_
.
left_to_right
()
?
1
:
-
1
;
for
(
size_t
step_index
=
0
;
step_index
<
state
.
num_steps
;
++
step_index
,
target_index
+=
target_increment
)
{
const
MutableVector
<
float
>
row
=
inputs
.
row
(
step_index
);
float
*
output
=
row
.
data
();
for
(
size_t
channel_id
=
0
;
channel_id
<
state
.
features
.
num_channels
();
++
channel_id
)
{
embedding
=
state
.
features
.
GetEmbedding
(
channel_id
,
target_index
);
memcpy
(
output
,
embedding
.
data
(),
embedding
.
size
()
*
sizeof
(
float
));
output
+=
embedding
.
size
();
}
for
(
size_t
channel_id
=
0
;
channel_id
<
state
.
links
.
num_channels
();
++
channel_id
)
{
state
.
links
.
Get
(
channel_id
,
target_index
,
&
embedding
,
&
is_out_of_bounds
);
memcpy
(
output
,
embedding
.
data
(),
embedding
.
size
()
*
sizeof
(
float
));
output
+=
embedding
.
size
();
}
DCHECK_EQ
(
output
,
row
.
end
());
}
return
Matrix
<
float
>
(
inputs
);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SequenceBulkDynamicComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_bulk_dynamic_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
size_t
kNumSteps
=
50
;
constexpr
size_t
kFixedDim
=
11
;
constexpr
size_t
kFixedVocabularySize
=
123
;
constexpr
float
kFixedValue
=
0.5
;
constexpr
size_t
kLinkedDim
=
13
;
constexpr
float
kLinkedValue
=
1.25
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
char
kLogitsName
[]
=
"logits"
;
constexpr
size_t
kLogitsDim
=
kFixedDim
+
kLinkedDim
;
// Adds one to all inputs.
class
BulkAddOne
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
network_state_manager
->
AddLayer
(
kLogitsName
,
kLogitsDim
,
&
logits_handle_
);
}
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
kLogitsName
;
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
{
const
MutableMatrix
<
float
>
logits
=
session_state
->
network_states
.
GetLayer
(
logits_handle_
);
for
(
size_t
row
=
0
;
row
<
inputs
.
num_rows
();
++
row
)
{
for
(
size_t
column
=
0
;
column
<
inputs
.
num_columns
();
++
column
)
{
logits
.
row
(
row
)[
column
]
=
inputs
.
row
(
row
)[
column
]
+
1.0
;
}
}
return
tensorflow
::
Status
::
OK
();
}
private:
// Output logits.
LayerHandle
<
float
>
logits_handle_
;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkAddOne
);
// A component that also prefers other but is triggered on the presence of a
// resource. This can be used to cause a component selection conflict.
class
ImTheWorst
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
component_spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
component_spec
.
resource_size
()
>
0
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheWorst
);
// Extractor that produces a sequence of zeros.
class
ExtractZeros
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
assign
(
kNumSteps
,
0
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
ExtractZeros
);
// Linker that produces a sequence of zeros.
class
LinkZeros
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
assign
(
kNumSteps
,
0
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkZeros
);
// Predictor that captures the logits.
class
CaptureLogits
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
)
const
override
{
logits_
=
logits
;
return
tensorflow
::
Status
::
OK
();
}
// Returns the captured logits.
static
Matrix
<
float
>
GetCapturedLogits
()
{
return
logits_
;
}
private:
// Logits from the most recent call to Predict().
static
Matrix
<
float
>
logits_
;
};
Matrix
<
float
>
CaptureLogits
::
logits_
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
CaptureLogits
);
class
SequenceBulkDynamicComponentTest
:
public
NetworkTestBase
{
protected:
SequenceBulkDynamicComponentTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
}
// Returns a spec that the network supports.
ComponentSpec
GetSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
kTestComponentName
);
component_spec
.
set_num_actions
(
kLogitsDim
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"AddOne"
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SequenceBulkDynamicComponent"
);
auto
&
component_parameters
=
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
();
component_parameters
[
"sequence_extractors"
]
=
"ExtractZeros"
;
component_parameters
[
"sequence_linkers"
]
=
"LinkZeros"
;
component_parameters
[
"sequence_predictor"
]
=
"CaptureLogits"
;
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
kFixedDim
);
fixed_feature
->
set_vocabulary_size
(
kFixedVocabularySize
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
return
component_spec
;
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
)
{
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kFixedVocabularySize
,
kFixedDim
,
kFixedValue
);
std
::
unique_ptr
<
Component
>
component
;
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"SequenceBulkDynamicComponent"
,
&
component
));
TF_RETURN_IF_ERROR
(
component
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
// Allocates network states for a few steps.
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kLinkedValue
);
StartComponent
(
0
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
return
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
);
}
// Input batch injected into Evaluate() by default.
InputBatchCache
input_
;
// Backend injected into Evaluate().
SequenceBackend
backend_
;
};
// Tests that the supported spec is supported.
TEST_F
(
SequenceBulkDynamicComponentTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
GetSupportedSpec
();
string
component_type
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
component_type
));
EXPECT_EQ
(
component_type
,
"SequenceBulkDynamicComponent"
);
TF_ASSERT_OK
(
Run
(
component_spec
));
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetCapturedLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kFixedDim
+
kLinkedDim
);
for
(
size_t
row
=
0
;
row
<
kNumSteps
;
++
row
)
{
size_t
column
=
0
;
for
(;
column
<
kFixedDim
;
++
column
)
{
EXPECT_EQ
(
logits
.
row
(
row
)[
column
],
kFixedValue
+
1.0
);
}
for
(;
column
<
kFixedDim
+
kLinkedDim
;
++
column
)
{
EXPECT_EQ
(
logits
.
row
(
row
)[
column
],
kLinkedValue
+
1.0
);
}
}
}
// Tests that links cannot be recurrent.
TEST_F
(
SequenceBulkDynamicComponentTest
,
ForbidRecurrences
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
kTestComponentName
);
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that the component prefers others.
TEST_F
(
SequenceBulkDynamicComponentTest
,
PrefersOthers
)
{
ComponentSpec
component_spec
=
GetSupportedSpec
();
component_spec
.
add_resource
();
// Adding a resource triggers the ImTheWorst component, which also prefers
// itself and leads to a selection conflict.
string
component_type
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
component_type
),
test
::
IsErrorWithSubstr
(
"both think they should be dis-preferred"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_component_transformer.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns true if the |component_spec| has recurrent links.
bool
IsRecurrent
(
const
ComponentSpec
&
component_spec
)
{
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
return
true
;
}
return
false
;
}
// Returns the sequence-based version of the |component_type| with specification
// |component_spec|, or an empty string if there is no sequence-based version.
string
GetSequenceComponentType
(
const
string
&
component_type
,
const
ComponentSpec
&
component_spec
)
{
// TODO(googleuser): Implement a SequenceDynamicComponent that can handle
// recurrent links. This may require changes to the NetworkUnit API.
static
const
char
*
kSupportedComponentTypes
[]
=
{
"BulkDynamicComponent"
,
//
"BulkLstmComponent"
,
//
"MyelinDynamicComponent"
,
//
};
for
(
const
char
*
supported_type
:
kSupportedComponentTypes
)
{
if
(
component_type
==
supported_type
)
{
return
tensorflow
::
strings
::
StrCat
(
"Sequence"
,
supported_type
);
}
}
// Also support non-recurrent DynamicComponents. The BulkDynamicComponent
// requires determinism, but the SequenceBulkDynamicComponent does not, so
// it's not sufficient to only upgrade from BulkDynamicComponent.
if
(
component_type
==
"DynamicComponent"
&&
!
IsRecurrent
(
component_spec
))
{
return
"SequenceBulkDynamicComponent"
;
}
return
string
();
}
// Returns the |status| but coerces NOT_FOUND to OK. Sets |found| to false iff
// the |status| was NOT_FOUND.
tensorflow
::
Status
AllowNotFound
(
const
tensorflow
::
Status
&
status
,
bool
*
found
)
{
*
found
=
status
.
code
()
!=
tensorflow
::
error
::
NOT_FOUND
;
return
*
found
?
status
:
tensorflow
::
Status
::
OK
();
}
// Transformer that checks whether a sequence-based component implementation
// could be used and, if compatible, modifies the ComponentSpec accordingly.
class
SequenceComponentTransformer
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
override
;
};
tensorflow
::
Status
SequenceComponentTransformer
::
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
{
const
int
num_features
=
component_spec
->
fixed_feature_size
()
+
component_spec
->
linked_feature_size
();
if
(
num_features
==
0
)
return
tensorflow
::
Status
::
OK
();
// Look for supporting SequenceExtractors.
bool
found
=
false
;
string
extractor_types
;
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
->
fixed_feature
())
{
string
type
;
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequenceExtractor
::
Select
(
channel
,
*
component_spec
,
&
type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
tensorflow
::
strings
::
StrAppend
(
&
extractor_types
,
type
,
","
);
}
if
(
!
extractor_types
.
empty
())
extractor_types
.
pop_back
();
// remove comma
// Look for supporting SequenceLinkers.
string
linker_types
;
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
->
linked_feature
())
{
string
type
;
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequenceLinker
::
Select
(
channel
,
*
component_spec
,
&
type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
tensorflow
::
strings
::
StrAppend
(
&
linker_types
,
type
,
","
);
}
if
(
!
linker_types
.
empty
())
linker_types
.
pop_back
();
// remove comma
// Look for a supporting SequencePredictor, if predictions are necessary.
string
predictor_type
;
if
(
!
TransitionSystemTraits
(
*
component_spec
).
is_deterministic
)
{
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequencePredictor
::
Select
(
*
component_spec
,
&
predictor_type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
}
// Look for a supporting sequence-based component type.
const
string
sequence_component_type
=
GetSequenceComponentType
(
component_type
,
*
component_spec
);
if
(
sequence_component_type
.
empty
())
return
tensorflow
::
Status
::
OK
();
// Success; make modifications.
component_spec
->
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
RegisteredModuleSpec
*
builder
=
component_spec
->
mutable_component_builder
();
builder
->
set_registered_name
(
sequence_component_type
);
(
*
builder
->
mutable_parameters
())[
"sequence_extractors"
]
=
extractor_types
;
(
*
builder
->
mutable_parameters
())[
"sequence_linkers"
]
=
linker_types
;
(
*
builder
->
mutable_parameters
())[
"sequence_predictor"
]
=
predictor_type
;
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
SequenceComponentTransformer
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_component_transformer_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Arbitrary supported component type.
constexpr
char
kSupportedComponentType
[]
=
"MyelinDynamicComponent"
;
// Sequence-based version of the component type.
constexpr
char
kTransformedComponentType
[]
=
"SequenceMyelinDynamicComponent"
;
// Trivial extractor that supports components named "supported".
class
SupportIfNamedSupportedExtractor
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
SupportIfNamedSupportedExtractor
);
// Trivial extractor that supports components if they have a resource. This is
// used to generate a "multiple supported extractors" conflict.
class
SupportIfHasResourcesExtractor
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
resource_size
()
>
0
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
SupportIfHasResourcesExtractor
);
// Trivial linker that supports components named "supported".
class
SupportIfNamedSupportedLinker
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SupportIfNamedSupportedLinker
);
// Trivial predictor that supports components named "supported".
class
SupportIfNamedSupportedPredictor
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
SupportIfNamedSupportedPredictor
);
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"supported"
);
component_spec
.
set_num_actions
(
10
);
component_spec
.
add_fixed_feature
();
component_spec
.
add_fixed_feature
();
component_spec
.
add_linked_feature
();
component_spec
.
add_linked_feature
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
kSupportedComponentType
);
return
component_spec
;
}
// Tests that a compatible spec is modified to use a new backend and component
// builder with SequenceExtractors, SequenceLinkers, and SequencePredictor.
TEST
(
SequenceComponentTransformerTest
,
Compatible
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
kTransformedComponentType
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"SupportIfNamedSupportedPredictor"
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
// Tests that a compatible deterministic spec is modified to use a new backend
// and component builder with SequenceExtractors and SequenceLinkers only.
TEST
(
SequenceComponentTransformerTest
,
CompatibleNoPredictor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
1
);
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
kTransformedComponentType
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
""
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
// Tests that a ComponentSpec with no features is incompatible.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleNoFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec with the wrong component builder is incompatible.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleComponentBuilder
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec is incompatible if it is not supported by any
// SequenceExtractor.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleNoSupportingSequenceExtractor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_name
(
"bad"
);
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec fails if multiple SequenceExtractors support it.
TEST
(
SequenceComponentTransformerTest
,
FailIfMultipleSupportingSequenceExtractors
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
add_resource
();
// triggers SupportIfHasResourcesExtractor
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Multiple SequenceExtractors support channel"
));
}
// Tests that a DynamicComponent is not upgraded if it is recurrent.
TEST
(
SequenceComponentTransformerTest
,
RecurrentDynamicComponent
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"DynamicComponent"
);
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
component_spec
.
name
());
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a DynamicComponent is upgraded to SequenceBulkDynamicComponent if
// it is non-recurrent.
TEST
(
SequenceComponentTransformerTest
,
NonRecurrentDynamicComponent
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"DynamicComponent"
);
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SequenceBulkDynamicComponent"
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"SupportIfNamedSupportedPredictor"
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceExtractor
::
Select
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequenceExtractor
>
current_extractor
(
factory_function
());
if
(
!
current_extractor
->
Supports
(
channel
,
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequenceExtractors support channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequenceExtractor supports channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceExtractor
::
New
(
const
string
&
name
,
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceExtractor
>
*
extractor
)
{
std
::
unique_ptr
<
SequenceExtractor
>
matching_extractor
;
TF_RETURN_IF_ERROR
(
SequenceExtractor
::
CreateOrError
(
name
,
&
matching_extractor
));
TF_RETURN_IF_ERROR
(
matching_extractor
->
Initialize
(
channel
,
component_spec
));
// Success; make modifications.
*
extractor
=
std
::
move
(
matching_extractor
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Extractor"
,
dragnn
::
runtime
::
SequenceExtractor
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_extractor.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for feature extraction for sequence inputs.
//
// This extractor can be used to avoid ComputeSession overhead in simple cases;
// for example, extracting a sequence of character or word IDs for an LSTM.
class
SequenceExtractor
:
public
RegisterableClass
<
SequenceExtractor
>
{
public:
// Sets |extractor| to an instance of the subclass named |name| initialized
// from the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceExtractor
>
*
extractor
);
SequenceExtractor
(
const
SequenceExtractor
&
)
=
delete
;
SequenceExtractor
&
operator
=
(
const
SequenceExtractor
&
)
=
delete
;
virtual
~
SequenceExtractor
()
=
default
;
// Sets |name| to the registered name of the SequenceExtractor that supports
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing. The returned statuses include:
// * OK: If a supporting SequenceExtractor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Overwrites |ids| with the sequence of features extracted from the |input|.
// On error, returns non-OK.
virtual
tensorflow
::
Status
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
=
0
;
protected:
SequenceExtractor
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequenceExtractor
>::
Create
;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual
bool
Supports
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Extractor"
,
dragnn
::
runtime
::
SequenceExtractor
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceExtractor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequenceExtractorTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequenceExtractor
::
New
(
name
,
{},
component_spec
,
&
extractor
));
EXPECT_NE
(
extractor
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequenceExtractorTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequenceExtractor
::
New
(
name
,
{},
component_spec
,
&
extractor
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
extractor
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequenceExtractorTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequenceExtractor supports channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequenceExtractorTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequenceExtractor
::
New
(
"Unsupported"
,
{},
component_spec
,
&
extractor
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Extractor"
));
EXPECT_EQ
(
extractor
,
nullptr
);
}
// Tests that multiple supporting extractors are reported as INTERNAL errors.
TEST
(
SequenceExtractorTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequenceExtractors support channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_features.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceFeatureManager
::
Reset
(
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
)
{
const
size_t
num_channels
=
fixed_embedding_manager
->
channel_configs_
.
size
();
if
(
component_spec
.
fixed_feature_size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between FixedEmbeddingManager ("
,
num_channels
,
") and ComponentSpec ("
,
component_spec
.
fixed_feature_size
(),
")"
);
}
if
(
sequence_extractor_types
.
size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between FixedEmbeddingManager ("
,
num_channels
,
") and SequenceExtractors ("
,
sequence_extractor_types
.
size
(),
")"
);
}
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
size
()
>
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Multi-embedding fixed features are not supported for channel: "
,
channel
.
ShortDebugString
());
}
}
std
::
vector
<
ChannelConfig
>
local_configs
;
// avoid modification on error
for
(
size_t
channel_id
=
0
;
channel_id
<
num_channels
;
++
channel_id
)
{
local_configs
.
emplace_back
();
ChannelConfig
&
channel_config
=
local_configs
.
back
();
const
FixedEmbeddingManager
::
ChannelConfig
&
wrapped_config
=
fixed_embedding_manager
->
channel_configs_
[
channel_id
];
channel_config
.
is_embedded
=
wrapped_config
.
is_embedded
;
channel_config
.
embedding_matrix
=
wrapped_config
.
embedding_matrix
;
TF_RETURN_IF_ERROR
(
SequenceExtractor
::
New
(
sequence_extractor_types
[
channel_id
],
component_spec
.
fixed_feature
(
channel_id
),
component_spec
,
&
channel_config
.
extractor
));
}
// Success; make modifications.
zeros_
=
fixed_embedding_manager
->
zeros_
.
view
();
channel_configs_
=
std
::
move
(
local_configs
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceFeatures
::
Reset
(
const
SequenceFeatureManager
*
manager
,
InputBatchCache
*
input
)
{
manager_
=
manager
;
zeros_
=
manager
->
zeros_
;
num_channels_
=
manager
->
channel_configs_
.
size
();
num_steps_
=
0
;
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.ids sub-vector is never deallocated.
if
(
num_channels_
>
channels_
.
size
())
channels_
.
resize
(
num_channels_
);
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
Channel
&
channel
=
channels_
[
channel_id
];
const
SequenceFeatureManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
channel
.
embedding_matrix
=
channel_config
.
embedding_matrix
;
TF_RETURN_IF_ERROR
(
channel_config
.
extractor
->
GetIds
(
input
,
&
channel
.
ids
));
if
(
channel_id
==
0
)
{
num_steps_
=
channel
.
ids
.
size
();
}
else
if
(
channel
.
ids
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent feature sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
ids
.
size
(),
" but expected "
,
num_steps_
);
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_features.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting fixed embeddings for sequence-based
// models. Analogous to FixedEmbeddingManager and FixedEmbeddings, but uses
// SequenceExtractor instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#define DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Manager for fixed embeddings for sequence-based models. This is a wrapper
// around the FixedEmbeddingManager.
class
SequenceFeatureManager
{
public:
// Creates an empty manager.
SequenceFeatureManager
()
=
default
;
// Resets this to wrap the |fixed_embedding_manager|, which must outlive this.
// The |sequence_extractor_types| should name one SequenceExtractor subclass
// per channel; e.g., "SyntaxNetCharacterSequenceExtractor". This initializes
// each SequenceExtractor from the |component_spec|. On error, returns non-OK
// and does not modify this.
tensorflow
::
Status
Reset
(
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
);
// Accessors.
size_t
num_channels
()
const
{
return
channel_configs_
.
size
();
}
private:
friend
class
SequenceFeatures
;
// Configuration for a single fixed embedding channel.
struct
ChannelConfig
{
// Whether this channel is embedded.
bool
is_embedded
=
true
;
// Embedding matrix of this channel. Only used if |is_embedded| is true.
Matrix
<
float
>
embedding_matrix
;
// Extractor for sequences of feature IDs.
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
};
// Array of zeros that can be substituted for missing feature IDs. This is a
// reference to the corresponding array in the FixedEmbeddingManager.
AlignedView
zeros_
;
// Ordered list of configurations for each channel.
std
::
vector
<
ChannelConfig
>
channel_configs_
;
};
// A set of fixed embeddings for a sequence-based model. Configured by a
// SequenceFeatureManager.
class
SequenceFeatures
{
public:
// Creates an empty set of embeddings.
SequenceFeatures
()
=
default
;
// Resets this to the sequences of fixed features managed by the |manager| on
// the |input|. The |manager| must live until this is destroyed or Reset(),
// and should not be modified during that time. On error, returns non-OK.
tensorflow
::
Status
Reset
(
const
SequenceFeatureManager
*
manager
,
InputBatchCache
*
input
);
// Returns the feature ID or embedding for the |target_index|'th element of
// the |channel_id|'th channel. Each method is only valid for a non-embedded
// or embedded channel, respectively.
int32
GetId
(
size_t
channel_id
,
size_t
target_index
)
const
;
Vector
<
float
>
GetEmbedding
(
size_t
channel_id
,
size_t
target_index
)
const
;
// Accessors.
size_t
num_channels
()
const
{
return
num_channels_
;
}
size_t
num_steps
()
const
{
return
num_steps_
;
}
private:
// Data associated with a single fixed embedding channel.
struct
Channel
{
// Embedding matrix of this channel. Only used for embedded channels.
Matrix
<
float
>
embedding_matrix
;
// Feature IDs for each step.
std
::
vector
<
int32
>
ids
;
};
// Manager from the most recent Reset().
const
SequenceFeatureManager
*
manager_
=
nullptr
;
// Zero vector from the most recent Reset().
AlignedView
zeros_
;
// Number of channels and steps from the most recent Reset().
size_t
num_channels_
=
0
;
size_t
num_steps_
=
0
;
// Ordered list of fixed embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std
::
vector
<
Channel
>
channels_
;
};
// Implementation details below.
inline
int32
SequenceFeatures
::
GetId
(
size_t
channel_id
,
size_t
target_index
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
DCHECK
(
!
manager_
->
channel_configs_
[
channel_id
].
is_embedded
);
const
Channel
&
channel
=
channels_
[
channel_id
];
return
channel
.
ids
[
target_index
];
}
inline
Vector
<
float
>
SequenceFeatures
::
GetEmbedding
(
size_t
channel_id
,
size_t
target_index
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
DCHECK
(
manager_
->
channel_configs_
[
channel_id
].
is_embedded
);
const
Channel
&
channel
=
channels_
[
channel_id
];
const
int32
id
=
channel
.
ids
[
target_index
];
return
id
<
0
?
Vector
<
float
>
(
zeros_
,
channel
.
embedding_matrix
.
num_columns
())
:
channel
.
embedding_matrix
.
row
(
id
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Number of transition steps to take in each component in the network.
const
size_t
kNumSteps
=
10
;
// A working one-channel ComponentSpec. This is intentionally identical to the
// first channel of |kMultiSpec|, so they can use the same embedding matrix.
const
char
kSingleSpec
[]
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
})"
;
const
size_t
kSingleRows
=
13
;
const
size_t
kSingleColumns
=
11
;
constexpr
float
kSingleValue
=
1.25
;
// A working multi-channel ComponentSpec.
const
char
kMultiSpec
[]
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
})"
;
// Fails to initialize.
class
FailToInitialize
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"No initialization for you!"
);
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
FailToInitialize
);
// Initializes OK, then fails to extract features.
class
FailToGetIds
:
public
FailToInitialize
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
errors
::
Internal
(
"No features for you!"
);
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
FailToGetIds
);
// Initializes OK and extracts the previous step.
class
ExtractPrevious
:
public
FailToGetIds
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
resize
(
kNumSteps
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
(
*
ids
)[
i
]
=
i
-
1
;
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
ExtractPrevious
);
// Initializes OK but produces the wrong number of features.
class
WrongNumberOfIds
:
public
FailToGetIds
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
resize
(
kNumSteps
+
1
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
WrongNumberOfIds
);
class
SequenceFeatureManagerTest
:
public
NetworkTestBase
{
protected:
// Creates a SequenceFeatureManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow
::
Status
ResetManager
(
const
string
&
component_spec_text
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kSingleRows
,
kSingleColumns
,
kSingleValue
);
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
return
manager_
.
Reset
(
&
fixed_embedding_manager_
,
component_spec
,
sequence_extractor_types
);
}
FixedEmbeddingManager
fixed_embedding_manager_
;
SequenceFeatureManager
manager_
;
};
// Tests that SequenceFeatureManager is empty by default.
TEST_F
(
SequenceFeatureManagerTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceFeatureManager is empty when reset to an empty spec.
TEST_F
(
SequenceFeatureManagerTest
,
EmptySpec
)
{
TF_EXPECT_OK
(
ResetManager
(
""
,
{}));
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceFeatureManager works with a single channel.
TEST_F
(
SequenceFeatureManagerTest
,
OneChannel
)
{
TF_EXPECT_OK
(
ResetManager
(
kSingleSpec
,
{
"ExtractPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
1
);
}
// Tests that SequenceFeatureManager works with multiple channels.
TEST_F
(
SequenceFeatureManagerTest
,
MultipleChannels
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
3
);
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F
(
SequenceFeatureManagerTest
,
MismatchedFixedManagerAndComponentSpec
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMultiSpec
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kSingleRows
,
kSingleColumns
,
kSingleValue
);
AddComponent
(
kTestComponentName
);
TF_ASSERT_OK
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
// Remove one fixed feature, resulting in a mismatch.
component_spec
.
mutable_fixed_feature
()
->
RemoveLast
();
EXPECT_THAT
(
manager_
.
Reset
(
&
fixed_embedding_manager_
,
component_spec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between FixedEmbeddingManager "
"(3) and ComponentSpec (2)"
));
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// SequenceExtractors are mismatched.
TEST_F
(
SequenceFeatureManagerTest
,
MismatchedFixedManagerAndSequenceExtractors
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between FixedEmbeddingManager "
"(3) and SequenceExtractors (2)"
));
}
// Tests that SequenceFeatureManager fails if a channel has multiple embeddings.
TEST_F
(
SequenceFeatureManagerTest
,
UnsupportedMultiEmbeddingChannel
)
{
const
string
kBadSpec
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 2 # bad
})"
;
EXPECT_THAT
(
ResetManager
(
kBadSpec
,
{
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Multi-embedding fixed features are not supported"
));
}
// Tests that SequenceFeatureManager fails if one of the SequenceExtractors
// fails to initialize.
TEST_F
(
SequenceFeatureManagerTest
,
FailToInitializeSequenceExtractor
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"FailToInitialize"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"No initialization for you!"
));
}
// Tests that SequenceFeatureManager is OK even if the SequenceExtractors would
// fail in GetIds().
TEST_F
(
SequenceFeatureManagerTest
,
ManagerDoesntCareAboutGetIds
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"FailToGetIds"
,
"FailToGetIds"
,
"FailToGetIds"
}));
}
class
SequenceFeaturesTest
:
public
SequenceFeatureManagerTest
{
protected:
// Resets the |sequence_features_| on the |manager_| and |input_batch_cache_|
// and returns the resulting status.
tensorflow
::
Status
ResetFeatures
()
{
return
sequence_features_
.
Reset
(
&
manager_
,
&
input_batch_cache_
);
}
InputBatchCache
input_batch_cache_
;
SequenceFeatures
sequence_features_
;
};
// Tests that SequenceFeatures is empty by default.
TEST_F
(
SequenceFeaturesTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
sequence_features_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_features_
.
num_steps
(),
0
);
}
// Tests that SequenceFeatures is empty when reset by an empty manager.
TEST_F
(
SequenceFeaturesTest
,
EmptyManager
)
{
TF_ASSERT_OK
(
ResetManager
(
""
,
{}));
TF_EXPECT_OK
(
ResetFeatures
());
EXPECT_EQ
(
sequence_features_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_features_
.
num_steps
(),
0
);
}
// Tests that SequenceFeatures fails when one of the SequenceExtractors fails.
TEST_F
(
SequenceFeaturesTest
,
FailToGetIds
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"FailToGetIds"
}));
EXPECT_THAT
(
ResetFeatures
(),
test
::
IsErrorWithSubstr
(
"No features for you!"
));
}
// Tests that SequenceFeatures fails when the SequenceExtractors produce
// different numbers of features.
TEST_F
(
SequenceFeaturesTest
,
MismatchedNumbersOfFeatures
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"WrongNumberOfIds"
}));
EXPECT_THAT
(
ResetFeatures
(),
test
::
IsErrorWithSubstr
(
"Inconsistent feature sequence lengths at "
"channel ID 2: got 11 but expected 10"
));
}
// Tests that SequenceFeatures works as expected on one channel.
TEST_F
(
SequenceFeaturesTest
,
SingleChannel
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"ExtractPrevious"
}));
TF_ASSERT_OK
(
ResetFeatures
());
ASSERT_EQ
(
sequence_features_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_features_
.
num_steps
(),
kNumSteps
);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
0
),
kSingleColumns
,
0.0
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
0
),
"is_embedded"
);
// The remaining feature IDs map to valid embedding rows.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
i
),
kSingleColumns
,
kSingleValue
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
i
),
"is_embedded"
);
}
}
// Tests that SequenceFeatures works as expected on multiple channels.
TEST_F
(
SequenceFeaturesTest
,
ManyChannels
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}));
TF_ASSERT_OK
(
ResetFeatures
());
ASSERT_EQ
(
sequence_features_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_features_
.
num_steps
(),
kNumSteps
);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
0
),
kSingleColumns
,
0.0
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
1
,
0
),
-
1
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
2
,
0
),
-
1
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
0
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
1
,
0
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
2
,
0
),
"is_embedded"
);
// The remaining features point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
i
),
kSingleColumns
,
kSingleValue
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
1
,
i
),
i
-
1
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
2
,
i
),
i
-
1
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
i
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
1
,
i
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
2
,
i
),
"is_embedded"
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_linker.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceLinker
::
Select
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequenceLinker
>
current_linker
(
factory_function
());
if
(
!
current_linker
->
Supports
(
channel
,
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequenceLinkers support channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequenceLinker supports channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceLinker
::
New
(
const
string
&
name
,
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceLinker
>
*
linker
)
{
std
::
unique_ptr
<
SequenceLinker
>
matching_linker
;
TF_RETURN_IF_ERROR
(
SequenceLinker
::
CreateOrError
(
name
,
&
matching_linker
));
TF_RETURN_IF_ERROR
(
matching_linker
->
Initialize
(
channel
,
component_spec
));
// Success; make modifications.
*
linker
=
std
::
move
(
matching_linker
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Linker"
,
dragnn
::
runtime
::
SequenceLinker
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_linker.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for link extraction for sequence inputs.
//
// This can be used to avoid ComputeSession overhead in simple cases; for
// example, extracting a sequence of identity or reverse-identity links.
class
SequenceLinker
:
public
RegisterableClass
<
SequenceLinker
>
{
public:
// Sets |linker| to an instance of the subclass named |name| initialized from
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceLinker
>
*
linker
);
SequenceLinker
(
const
SequenceLinker
&
)
=
delete
;
SequenceLinker
&
operator
=
(
const
SequenceLinker
&
)
=
delete
;
virtual
~
SequenceLinker
()
=
default
;
// Sets |name| to the registered name of the SequenceLinker that supports the
// |channel| of the |component_spec|. On error, returns non-OK and modifies
// nothing. The returned statuses include:
// * OK: If a supporting SequenceLinker was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Overwrites |links| with the sequence of translated link step indices for
// the |input|. Specifically, sets links[i] to the (possibly out-of-bounds)
// step index to fetch from the source component for the i'th element of the
// target sequence. Assumes that |source_num_steps| is the number of steps
// taken by the source component. On error, returns non-OK.
virtual
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
=
0
;
protected:
SequenceLinker
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequenceLinker
>::
Create
;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual
bool
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Linker"
,
dragnn
::
runtime
::
SequenceLinker
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceLinker, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequenceLinkerTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequenceLinker
::
New
(
name
,
{},
component_spec
,
&
linker
));
EXPECT_NE
(
linker
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequenceLinkerTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequenceLinker
::
New
(
name
,
{},
component_spec
,
&
linker
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
linker
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequenceLinkerTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequenceLinker supports channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequenceLinkerTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequenceLinker
::
New
(
"Unsupported"
,
{},
component_spec
,
&
linker
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Linker"
));
EXPECT_EQ
(
linker
,
nullptr
);
}
// Tests that multiple supporting linkers are reported as INTERNAL errors.
TEST
(
SequenceLinkerTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequenceLinkers support channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_links.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceLinkManager
::
Reset
(
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_linker_types
)
{
const
size_t
num_channels
=
linked_embedding_manager
->
channel_configs_
.
size
();
if
(
component_spec
.
linked_feature_size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between LinkedEmbeddingManager ("
,
num_channels
,
") and ComponentSpec ("
,
component_spec
.
linked_feature_size
(),
")"
);
}
if
(
sequence_linker_types
.
size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between LinkedEmbeddingManager ("
,
num_channels
,
") and SequenceLinkers ("
,
sequence_linker_types
.
size
(),
")"
);
}
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
embedding_dim
()
>=
0
)
{
return
tensorflow
::
errors
::
Unimplemented
(
"Transformed linked features are not supported for channel: "
,
channel
.
ShortDebugString
());
}
}
std
::
vector
<
ChannelConfig
>
local_configs
;
// avoid modification on error
for
(
size_t
channel_id
=
0
;
channel_id
<
num_channels
;
++
channel_id
)
{
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
channel_id
);
local_configs
.
emplace_back
();
ChannelConfig
&
channel_config
=
local_configs
.
back
();
channel_config
.
is_recurrent
=
channel
.
source_component
()
==
component_spec
.
name
();
channel_config
.
handle
=
linked_embedding_manager
->
channel_configs_
[
channel_id
].
source_handle
;
TF_RETURN_IF_ERROR
(
SequenceLinker
::
New
(
sequence_linker_types
[
channel_id
],
component_spec
.
linked_feature
(
channel_id
),
component_spec
,
&
channel_config
.
linker
));
}
// Success; make modifications.
zeros_
=
linked_embedding_manager
->
zeros_
.
view
();
channel_configs_
=
std
::
move
(
local_configs
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceLinks
::
Reset
(
bool
add_steps
,
const
SequenceLinkManager
*
manager
,
NetworkStates
*
network_states
,
InputBatchCache
*
input
)
{
zeros_
=
manager
->
zeros_
;
num_channels_
=
manager
->
channel_configs_
.
size
();
num_steps_
=
0
;
bool
have_num_steps
=
false
;
// true if |num_steps_| was assigned
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.links sub-vector is never deallocated.
if
(
num_channels_
>
channels_
.
size
())
channels_
.
resize
(
num_channels_
);
// Process non-recurrent links first.
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
const
SequenceLinkManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
if
(
channel_config
.
is_recurrent
)
continue
;
Channel
&
channel
=
channels_
[
channel_id
];
channel
.
layer
=
network_states
->
GetLayer
(
channel_config
.
handle
);
TF_RETURN_IF_ERROR
(
channel_config
.
linker
->
GetLinks
(
channel
.
layer
.
num_rows
(),
input
,
&
channel
.
links
));
if
(
!
have_num_steps
)
{
num_steps_
=
channel
.
links
.
size
();
have_num_steps
=
true
;
}
else
if
(
channel
.
links
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent link sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
links
.
size
(),
" but expected "
,
num_steps_
);
}
}
// Add steps to the |network_states|, if requested.
if
(
add_steps
)
{
if
(
!
have_num_steps
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Cannot infer the number of steps to add because there are no "
"non-recurrent links"
);
}
network_states
->
AddSteps
(
num_steps_
);
}
// Process recurrent links. These require that the current component in the
// |network_states| has been sized to the proper number of steps.
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
const
SequenceLinkManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
if
(
!
channel_config
.
is_recurrent
)
continue
;
Channel
&
channel
=
channels_
[
channel_id
];
channel
.
layer
=
network_states
->
GetLayer
(
channel_config
.
handle
);
TF_RETURN_IF_ERROR
(
channel_config
.
linker
->
GetLinks
(
channel
.
layer
.
num_rows
(),
input
,
&
channel
.
links
));
if
(
!
have_num_steps
)
{
num_steps_
=
channel
.
links
.
size
();
have_num_steps
=
true
;
}
else
if
(
channel
.
links
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent link sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
links
.
size
(),
" but expected "
,
num_steps_
);
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_links.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting linked embeddings for sequence-based
// models. Analogous to LinkedEmbeddingManager and LinkedEmbeddings, but uses
// SequenceLinker instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Manager for linked embeddings for sequence-based models. This is a wrapper
// around the LinkedEmbeddingManager.
class
SequenceLinkManager
{
public:
// Creates an empty manager.
SequenceLinkManager
()
=
default
;
// Resets this to wrap the |linked_embedding_manager|, which must outlive
// this. The |sequence_linker_types| should name one SequenceLinker subclass
// per channel; e.g., {"IdentitySequenceLinker", "ReversedSequenceLinker"}.
// This initializes each SequenceLinker from the |component_spec|. On error,
// returns non-OK and does not modify this.
tensorflow
::
Status
Reset
(
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_linker_types
);
// Accessors.
size_t
num_channels
()
const
{
return
channel_configs_
.
size
();
}
private:
friend
class
SequenceLinks
;
// Configuration for a single linked embedding channel.
struct
ChannelConfig
{
// Whether this link is recurrent.
bool
is_recurrent
=
false
;
// Handle to the source layer in the relevant NetworkStates.
LayerHandle
<
float
>
handle
;
// Extractor for sequences of translated link indices.
std
::
unique_ptr
<
SequenceLinker
>
linker
;
};
// Array of zeros that can be substituted for out-of-bounds embeddings. This
// is a reference to the corresponding array in the LinkedEmbeddingManager.
// See the large comment in linked_embeddings.cc for reference.
AlignedView
zeros_
;
// Ordered list of configurations for each channel.
std
::
vector
<
ChannelConfig
>
channel_configs_
;
};
// A set of linked embeddings for a sequence-based model. Configured by a
// SequenceLinkManager.
class
SequenceLinks
{
public:
// Creates an empty set of embeddings.
SequenceLinks
()
=
default
;
// Resets this to the sequences of linked embeddings managed by the |manager|
// on the |input|. Retrieves layers from the |network_states|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. If |add_steps| is true, then infers the number of steps
// from the non-recurrent links and adds steps to the |network_states| before
// processing the recurrent links. On error, returns non-OK.
//
// NB: Recurrent links are tricky, because the |network_states| must be filled
// with steps before processing recurrent links. There are two approaches:
// 1. Add steps to the |network_states| before calling Reset(). This only
// works if the component also has fixed features, which can be used to
// infer the number of steps.
// 2. Set |add_steps| to true, so steps are added during Reset(). This only
// works if the component also has non-recurrent links, which can be used
// to infer the number of steps.
// If a component only has recurrent links then neither of the above works,
// but such a component would be nonsensical: it recurses on itself with no
// external input.
tensorflow
::
Status
Reset
(
bool
add_steps
,
const
SequenceLinkManager
*
manager
,
NetworkStates
*
network_states
,
InputBatchCache
*
input
);
// Retrieves the linked embedding for the |target_index|'th element of the
// |channel_id|'th channel. Sets |embedding| to the linked embedding vector
// and sets |is_out_of_bounds| to true if the link is out of bounds.
void
Get
(
size_t
channel_id
,
size_t
target_index
,
Vector
<
float
>
*
embedding
,
bool
*
is_out_of_bounds
)
const
;
// Accessors.
size_t
num_channels
()
const
{
return
num_channels_
;
}
size_t
num_steps
()
const
{
return
num_steps_
;
}
private:
// Data associated with a single linked embedding channel.
struct
Channel
{
// Source layer activations.
Matrix
<
float
>
layer
;
// Translated link indices for each step.
std
::
vector
<
int32
>
links
;
};
// Zero vector from the most recent Reset().
AlignedView
zeros_
;
// Number of channels and steps from the most recent Reset().
size_t
num_channels_
=
0
;
size_t
num_steps_
=
0
;
// Ordered list of linked embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std
::
vector
<
Channel
>
channels_
;
};
// Implementation details below.
inline
void
SequenceLinks
::
Get
(
size_t
channel_id
,
size_t
target_index
,
Vector
<
float
>
*
embedding
,
bool
*
is_out_of_bounds
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
const
Channel
&
channel
=
channels_
[
channel_id
];
const
int32
link
=
channel
.
links
[
target_index
];
*
is_out_of_bounds
=
(
link
<
0
||
link
>=
channel
.
layer
.
num_rows
());
if
(
*
is_out_of_bounds
)
{
*
embedding
=
Vector
<
float
>
(
zeros_
,
channel
.
layer
.
num_columns
());
}
else
{
*
embedding
=
channel
.
layer
.
row
(
link
);
}
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Dimensions of the layers in the network (see ResetManager() below).
const
size_t
kPrevious1LayerDim
=
16
;
const
size_t
kPrevious2LayerDim
=
32
;
const
size_t
kRecurrentLayerDim
=
48
;
// Number of transition steps to take in each component in the network.
const
size_t
kNumSteps
=
10
;
// A working one-channel ComponentSpec.
const
char
kSingleSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})"
;
// A working multi-channel ComponentSpec.
const
char
kMultiSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'source_component_2'
source_layer: 'previous_2'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})"
;
// A recurrent-only ComponentSpec.
const
char
kRecurrentSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})"
;
// Fails to initialize.
class
FailToInitialize
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"No initialization for you!"
);
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
FailToInitialize
);
// Initializes OK, then fails to extract links.
class
FailToGetLinks
:
public
FailToInitialize
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
errors
::
Internal
(
"No links for you!"
);
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
FailToGetLinks
);
// Initializes OK and links to the previous step.
class
LinkToPrevious
:
public
FailToGetLinks
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
resize
(
source_num_steps
);
for
(
int
i
=
0
;
i
<
links
->
size
();
++
i
)
(
*
links
)[
i
]
=
i
-
1
;
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkToPrevious
);
// Initializes OK but produces the wrong number of links.
class
WrongNumberOfLinks
:
public
FailToGetLinks
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
resize
(
kNumSteps
+
1
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
WrongNumberOfLinks
);
class
SequenceLinkManagerTest
:
public
NetworkTestBase
{
protected:
// Sets up previous components and layers.
void
AddComponentsAndLayers
()
{
AddComponent
(
"source_component_0"
);
AddComponent
(
"source_component_1"
);
AddLayer
(
"previous_1"
,
kPrevious1LayerDim
);
AddComponent
(
"source_component_2"
);
AddLayer
(
"previous_2"
,
kPrevious2LayerDim
);
AddComponent
(
kTestComponentName
);
AddLayer
(
"recurrent"
,
kRecurrentLayerDim
);
}
// Creates a SequenceLinkManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow
::
Status
ResetManager
(
const
string
&
component_spec_text
,
const
std
::
vector
<
string
>
&
sequence_linker_types
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponentsAndLayers
();
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
return
manager_
.
Reset
(
&
linked_embedding_manager_
,
component_spec
,
sequence_linker_types
);
}
LinkedEmbeddingManager
linked_embedding_manager_
;
SequenceLinkManager
manager_
;
};
// Tests that SequenceLinkManager is empty by default.
TEST_F
(
SequenceLinkManagerTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceLinkManager is empty when reset to an empty spec.
TEST_F
(
SequenceLinkManagerTest
,
EmptySpec
)
{
TF_EXPECT_OK
(
ResetManager
(
""
,
{}));
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceLinkManager works with a single channel.
TEST_F
(
SequenceLinkManagerTest
,
OneChannel
)
{
TF_EXPECT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
1
);
}
// Tests that SequenceLinkManager works with multiple channels.
TEST_F
(
SequenceLinkManagerTest
,
MultipleChannels
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
3
);
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F
(
SequenceLinkManagerTest
,
MismatchedLinkedManagerAndComponentSpec
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMultiSpec
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponentsAndLayers
();
TF_ASSERT_OK
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
// Remove one linked feature, resulting in a mismatch.
component_spec
.
mutable_linked_feature
()
->
RemoveLast
();
EXPECT_THAT
(
manager_
.
Reset
(
&
linked_embedding_manager_
,
component_spec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between LinkedEmbeddingManager "
"(3) and ComponentSpec (2)"
));
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// SequenceLinkers are mismatched.
TEST_F
(
SequenceLinkManagerTest
,
MismatchedLinkedManagerAndSequenceLinkers
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between LinkedEmbeddingManager "
"(3) and SequenceLinkers (2)"
));
}
// Tests that SequenceLinkManager fails when the link is transformed.
TEST_F
(
SequenceLinkManagerTest
,
UnsupportedTransformedLink
)
{
const
string
kBadSpec
=
R"(linked_feature {
embedding_dim: 16 # bad
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})"
;
AddLinkedWeightMatrix
(
0
,
kPrevious1LayerDim
,
16
,
0.0
);
AddLinkedOutOfBoundsVector
(
0
,
16
,
0.0
);
EXPECT_THAT
(
ResetManager
(
kBadSpec
,
{
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Transformed linked features are not supported"
));
}
// Tests that SequenceLinkManager fails if one of the SequenceLinkers fails to
// initialize.
TEST_F
(
SequenceLinkManagerTest
,
FailToInitializeSequenceLinker
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"FailToInitialize"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"No initialization for you!"
));
}
// Tests that SequenceLinkManager is OK even if the SequenceLinkers would fail
// in GetLinks().
TEST_F
(
SequenceLinkManagerTest
,
ManagerDoesntCareAboutGetLinks
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"FailToGetLinks"
,
"FailToGetLinks"
,
"FailToGetLinks"
}));
}
// Values to fill each layer with.
const
float
kPrevious1LayerValue
=
1.0
;
const
float
kPrevious2LayerValue
=
2.0
;
const
float
kRecurrentLayerValue
=
3.0
;
class
SequenceLinksTest
:
public
SequenceLinkManagerTest
{
protected:
// Resets the |sequence_links_| using the |manager_|, |network_states_|, and
// |input_batch_cache_|, and returns the resulting status. Passes |add_steps|
// to Reset() and advances the current component by |num_steps|.
tensorflow
::
Status
ResetLinks
(
bool
add_steps
=
false
,
size_t
num_steps
=
kNumSteps
)
{
network_states_
.
Reset
(
&
network_state_manager_
);
// Fill components with steps.
StartComponent
(
kNumSteps
);
// source_component_0
StartComponent
(
kNumSteps
);
// source_component_1
StartComponent
(
kNumSteps
);
// source_component_2
StartComponent
(
num_steps
);
// current component
// Fill layers with values.
FillLayer
(
"source_component_1"
,
"previous_1"
,
kPrevious1LayerValue
);
FillLayer
(
"source_component_2"
,
"previous_2"
,
kPrevious2LayerValue
);
FillLayer
(
kTestComponentName
,
"recurrent"
,
kRecurrentLayerValue
);
return
sequence_links_
.
Reset
(
add_steps
,
&
manager_
,
&
network_states_
,
&
input_batch_cache_
);
}
InputBatchCache
input_batch_cache_
;
SequenceLinks
sequence_links_
;
};
// Tests that SequenceLinks is empty by default.
TEST_F
(
SequenceLinksTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
sequence_links_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks is empty when reset by an empty manager.
TEST_F
(
SequenceLinksTest
,
EmptyManager
)
{
TF_ASSERT_OK
(
ResetManager
(
""
,
{}));
TF_EXPECT_OK
(
ResetLinks
());
EXPECT_EQ
(
sequence_links_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks fails when one of the non-recurrent SequenceLinkers
// fails.
TEST_F
(
SequenceLinksTest
,
FailToGetNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"FailToGetLinks"
,
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"No links for you!"
));
}
// Tests that SequenceLinks fails when one of the recurrent SequenceLinkers
// fails.
TEST_F
(
SequenceLinksTest
,
FailToGetRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"FailToGetLinks"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"No links for you!"
));
}
// Tests that SequenceLinks fails when the non-recurrent SequenceLinkers produce
// different numbers of links.
TEST_F
(
SequenceLinksTest
,
MismatchedNumbersOfNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"WrongNumberOfLinks"
,
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"Inconsistent link sequence lengths at "
"channel ID 1: got 11 but expected 10"
));
}
// Tests that SequenceLinks fails when the recurrent SequenceLinkers produce
// different numbers of links.
TEST_F
(
SequenceLinksTest
,
MismatchedNumbersOfRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"WrongNumberOfLinks"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"Inconsistent link sequence lengths at "
"channel ID 2: got 11 but expected 10"
));
}
// Tests that SequenceLinks works as expected on one channel.
TEST_F
(
SequenceLinksTest
,
SingleChannel
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
const
Matrix
<
float
>
previous1
(
GetLayer
(
"source_component_1"
,
"previous_1"
));
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
0.0
);
// The remaining links point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
sequence_links_
.
Get
(
0
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
kPrevious1LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous1
.
row
(
i
-
1
).
data
());
}
}
// Tests that SequenceLinks works as expected on multiple channels.
TEST_F
(
SequenceLinksTest
,
ManyChannels
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
const
Matrix
<
float
>
previous1
(
GetLayer
(
"source_component_1"
,
"previous_1"
));
const
Matrix
<
float
>
previous2
(
GetLayer
(
"source_component_2"
,
"previous_2"
));
const
Matrix
<
float
>
recurrent
(
GetLayer
(
kTestComponentName
,
"recurrent"
));
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
0.0
);
sequence_links_
.
Get
(
1
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious2LayerDim
,
0.0
);
sequence_links_
.
Get
(
2
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kRecurrentLayerDim
,
0.0
);
// The remaining links point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
sequence_links_
.
Get
(
0
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
kPrevious1LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous1
.
row
(
i
-
1
).
data
());
sequence_links_
.
Get
(
1
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious2LayerDim
,
kPrevious2LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous2
.
row
(
i
-
1
).
data
());
sequence_links_
.
Get
(
2
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kRecurrentLayerDim
,
kRecurrentLayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
recurrent
.
row
(
i
-
1
).
data
());
}
}
// Tests that SequenceLinks is emptied when resetting to an empty manager after
// being reset to a non-empty manager.
TEST_F
(
SequenceLinksTest
,
ResetToEmptyAfterNonEmpty
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
SequenceLinkManager
manager
;
TF_ASSERT_OK
(
sequence_links_
.
Reset
(
/*add_steps=*/
false
,
&
manager
,
&
network_states_
,
&
input_batch_cache_
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
0
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks fails when adding steps to a component with no
// non-recurrent links.
TEST_F
(
SequenceLinksTest
,
AddStepsWithNoNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kRecurrentSpec
,
{
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(
/*add_steps=*/
true
),
test
::
IsErrorWithSubstr
(
"Cannot infer the number of steps to add because "
"there are no non-recurrent links"
));
}
// Tests that SequenceLinks produces no links when processing a component with
// only recurrent links, and when the NetworkStates has no steps.
TEST_F
(
SequenceLinksTest
,
RecurrentLinksWithNoSteps
)
{
TF_ASSERT_OK
(
ResetManager
(
kRecurrentSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
(
/*add_steps=*/
false
,
/*num_steps=*/
0
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks properly infers the number of steps and adds them
// when processing a component with both non-recurrent and recurrent links.
TEST_F
(
SequenceLinksTest
,
AddStepsWithNonRecurrentAndRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
(
/*add_steps=*/
true
,
/*num_steps=*/
0
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
5
6
7
8
9
10
11
12
13
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment