Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3183 deletions
+0
-3183
research/syntaxnet/dragnn/runtime/sequence_model.cc
research/syntaxnet/dragnn/runtime/sequence_model.cc
+0
-193
research/syntaxnet/dragnn/runtime/sequence_model.h
research/syntaxnet/dragnn/runtime/sequence_model.h
+0
-143
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
+0
-550
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
+0
-73
research/syntaxnet/dragnn/runtime/sequence_predictor.h
research/syntaxnet/dragnn/runtime/sequence_predictor.h
+0
-94
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
+0
-158
research/syntaxnet/dragnn/runtime/session_state.h
research/syntaxnet/dragnn/runtime/session_state.h
+0
-42
research/syntaxnet/dragnn/runtime/session_state_pool.cc
research/syntaxnet/dragnn/runtime/session_state_pool.cc
+0
-57
research/syntaxnet/dragnn/runtime/session_state_pool.h
research/syntaxnet/dragnn/runtime/session_state_pool.h
+0
-103
research/syntaxnet/dragnn/runtime/session_state_pool_test.cc
research/syntaxnet/dragnn/runtime/session_state_pool_test.cc
+0
-85
research/syntaxnet/dragnn/runtime/stateless_component_transformer.cc
...ntaxnet/dragnn/runtime/stateless_component_transformer.cc
+0
-60
research/syntaxnet/dragnn/runtime/stateless_component_transformer_test.cc
...et/dragnn/runtime/stateless_component_transformer_test.cc
+0
-63
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_extractor.cc
.../dragnn/runtime/syntaxnet_character_sequence_extractor.cc
+0
-153
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_extractor_test.cc
...nn/runtime/syntaxnet_character_sequence_extractor_test.cc
+0
-195
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_linkers.cc
...et/dragnn/runtime/syntaxnet_character_sequence_linkers.cc
+0
-216
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_linkers_test.cc
...agnn/runtime/syntaxnet_character_sequence_linkers_test.cc
+0
-304
research/syntaxnet/dragnn/runtime/syntaxnet_head_selection_component.cc
...xnet/dragnn/runtime/syntaxnet_head_selection_component.cc
+0
-90
research/syntaxnet/dragnn/runtime/syntaxnet_head_selection_component_test.cc
...dragnn/runtime/syntaxnet_head_selection_component_test.cc
+0
-256
research/syntaxnet/dragnn/runtime/syntaxnet_mst_solver_component.cc
...yntaxnet/dragnn/runtime/syntaxnet_mst_solver_component.cc
+0
-93
research/syntaxnet/dragnn/runtime/syntaxnet_mst_solver_component_test.cc
...net/dragnn/runtime/syntaxnet_mst_solver_component_test.cc
+0
-255
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/sequence_model.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Proper backend for sequence-based models.
constexpr
char
kSupportedBackend
[]
=
"SequenceBackend"
;
// Attributes for sequence-based comopnents, attached to the component builder.
// See SequenceComponentTransformer.
struct
ComponentBuilderAttributes
:
public
Attributes
{
// Registered names of the sequence extractors to use.
Mandatory
<
std
::
vector
<
string
>>
sequence_extractors
{
"sequence_extractors"
,
this
};
// Registered names of the sequence linkers to use per channel, in order.
Mandatory
<
std
::
vector
<
string
>>
sequence_linkers
{
"sequence_linkers"
,
this
};
// Registered name of the sequence predictor to use.
Mandatory
<
string
>
sequence_predictor
{
"sequence_predictor"
,
this
};
};
}
// namespace
bool
SequenceModel
::
Supports
(
const
ComponentSpec
&
component_spec
)
{
// Require single-embedding fixed and linked features.
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
size
()
!=
1
)
return
false
;
}
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
size
()
!=
1
)
return
false
;
}
const
bool
has_fixed_feature
=
component_spec
.
fixed_feature_size
()
>
0
;
bool
has_recurrent_link
=
false
;
bool
has_non_recurrent_link
=
false
;
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
{
has_recurrent_link
=
true
;
}
else
{
has_non_recurrent_link
=
true
;
}
}
// Recurrent links must be accompanied by fixed features or non-recurrent
// links, so the number of recurrent steps can be pre-computed.
if
(
has_recurrent_link
&&
!
has_fixed_feature
&&
!
has_non_recurrent_link
)
{
return
false
;
}
const
int
num_features
=
component_spec
.
fixed_feature_size
()
+
component_spec
.
linked_feature_size
();
return
component_spec
.
backend
().
registered_name
()
==
kSupportedBackend
&&
num_features
>
0
;
}
tensorflow
::
Status
SequenceModel
::
Initialize
(
const
ComponentSpec
&
component_spec
,
const
string
&
logits_name
,
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
NetworkStateManager
*
network_state_manager
)
{
component_name_
=
component_spec
.
name
();
if
(
component_spec
.
backend
().
registered_name
()
!=
kSupportedBackend
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Invalid component backend: "
,
component_spec
.
backend
().
registered_name
());
}
TransitionSystemTraits
traits
(
component_spec
);
deterministic_
=
traits
.
is_deterministic
;
left_to_right_
=
traits
.
is_left_to_right
;
ComponentBuilderAttributes
component_builder_attributes
;
TF_RETURN_IF_ERROR
(
component_builder_attributes
.
Reset
(
component_spec
.
component_builder
().
parameters
()));
TF_RETURN_IF_ERROR
(
sequence_feature_manager_
.
Reset
(
fixed_embedding_manager
,
component_spec
,
component_builder_attributes
.
sequence_extractors
()));
TF_RETURN_IF_ERROR
(
sequence_link_manager_
.
Reset
(
linked_embedding_manager
,
component_spec
,
component_builder_attributes
.
sequence_linkers
()));
have_fixed_features_
=
sequence_feature_manager_
.
num_channels
()
>
0
;
have_linked_features_
=
sequence_link_manager_
.
num_channels
()
>
0
;
if
(
!
have_fixed_features_
&&
!
have_linked_features_
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"No fixed or linked features"
);
}
if
(
!
deterministic_
)
{
size_t
dimension
=
0
;
TF_RETURN_IF_ERROR
(
network_state_manager
->
LookupLayer
(
component_name_
,
logits_name
,
&
dimension
,
&
logits_handle_
));
if
(
dimension
!=
component_spec
.
num_actions
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Logits dimension mismatch between NetworkStates ("
,
dimension
,
") and ComponentSpec ("
,
component_spec
.
num_actions
(),
")"
);
}
TF_RETURN_IF_ERROR
(
SequencePredictor
::
New
(
component_builder_attributes
.
sequence_predictor
(),
component_spec
,
&
sequence_predictor_
));
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceModel
::
Preprocess
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
EvaluateState
*
evaluate_state
)
const
{
InputBatchCache
*
input_batch_cache
=
compute_session
->
GetInputBatchCache
();
if
(
input_batch_cache
==
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Null input batch"
);
}
// The feature handling below is complicated by the need to support recurrent
// links. See the comment on SequenceLinks::Reset().
NetworkStates
&
network_states
=
session_state
->
network_states
;
TF_RETURN_IF_ERROR
(
evaluate_state
->
features
.
Reset
(
&
sequence_feature_manager_
,
input_batch_cache
));
if
(
have_fixed_features_
)
{
network_states
.
AddSteps
(
evaluate_state
->
features
.
num_steps
());
}
TF_RETURN_IF_ERROR
(
evaluate_state
->
links
.
Reset
(
/*add_steps=*/
!
have_fixed_features_
,
&
sequence_link_manager_
,
&
network_states
,
input_batch_cache
));
// Initialize() ensures that there is at least one fixed or linked feature;
// use it to determine the number of steps.
size_t
num_steps
=
0
;
if
(
have_fixed_features_
&&
have_linked_features_
)
{
num_steps
=
evaluate_state
->
features
.
num_steps
();
if
(
num_steps
!=
evaluate_state
->
links
.
num_steps
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Sequence length mismatch between fixed features ("
,
num_steps
,
") and linked features ("
,
evaluate_state
->
links
.
num_steps
(),
")"
);
}
}
else
if
(
have_fixed_features_
)
{
num_steps
=
evaluate_state
->
features
.
num_steps
();
}
else
{
num_steps
=
evaluate_state
->
links
.
num_steps
();
}
// Tell the backend the current input size, so it can handle requests for
// linked features from downstream components.
static_cast
<
SequenceBackend
*>
(
compute_session
->
GetReadiedComponent
(
component_name_
))
->
SetSequenceSize
(
num_steps
);
evaluate_state
->
num_steps
=
num_steps
;
evaluate_state
->
input
=
input_batch_cache
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceModel
::
Predict
(
const
NetworkStates
&
network_states
,
EvaluateState
*
evaluate_state
)
const
{
if
(
!
deterministic_
)
{
const
Matrix
<
float
>
logits
(
network_states
.
GetLayer
(
logits_handle_
));
TF_RETURN_IF_ERROR
(
sequence_predictor_
->
Predict
(
logits
,
evaluate_state
->
input
));
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_model.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#define DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A class that configures and helps evaluate a sequence-based model.
//
// This class requires the SequenceBackend component backend and elides most of
// the ComputeSession feature extraction and transition system overhead.
class
SequenceModel
{
public:
// State associated with a single evaluation of the model.
struct
EvaluateState
{
// Number of transition steps in the current sequence.
size_t
num_steps
=
0
;
// Current input batch.
InputBatchCache
*
input
=
nullptr
;
// Sequence-based fixed features.
SequenceFeatures
features
;
// Sequence-based linked embeddings.
SequenceLinks
links
;
};
// Creates an uninitialized model. Call Initialize() before use.
SequenceModel
()
=
default
;
// Returns true if the |component_spec| is compatible with a sequence model.
static
bool
Supports
(
const
ComponentSpec
&
component_spec
);
// Initalizes this from the configuration in the |component_spec|. Wraps the
// |fixed_embedding_manager| and |linked_embedding_manager| in sequence-based
// versions, and requests layers from the |network_state_manager|. All of the
// managers must outlive this. If the transition system is non-deterministic,
// uses the layer named |logits_name| to make predictions later in Predict();
// otherwise, |logits_name| is ignored and Predict() does nothing. On error,
// returns non-OK.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
const
string
&
logits_name
,
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
NetworkStateManager
*
network_state_manager
);
// Resets the |evaluate_state| to values derived from the |session_state| and
// |compute_session|. Also updates the NetworkStates in the |session_state|
// and the current component of the |compute_session| with the length of the
// current sequence. Call this before producing output layers. On error,
// returns non-OK.
tensorflow
::
Status
Preprocess
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
EvaluateState
*
evaluate_state
)
const
;
// If applicable, makes predictions based on the logits in |network_states|
// and applies them to the input in the |evaluate_state|. Call this after
// producing output layers. On error, returns non-OK.
tensorflow
::
Status
Predict
(
const
NetworkStates
&
network_states
,
EvaluateState
*
evaluate_state
)
const
;
// Accessors.
bool
deterministic
()
const
{
return
deterministic_
;
}
bool
left_to_right
()
const
{
return
left_to_right_
;
}
const
SequenceLinkManager
&
sequence_link_manager
()
const
;
const
SequenceFeatureManager
&
sequence_feature_manager
()
const
;
private:
// Name of the component that this model is a part of.
string
component_name_
;
// Whether the underlying transition system is deterministic.
bool
deterministic_
=
false
;
// Whether to process sequences from left to right.
bool
left_to_right_
=
true
;
// Whether fixed or linked features are present.
bool
have_fixed_features_
=
false
;
bool
have_linked_features_
=
false
;
// Handle to the logits layer. Only used if |deterministic_| is false.
LayerHandle
<
float
>
logits_handle_
;
// Manager for sequence-based feature extractors.
SequenceFeatureManager
sequence_feature_manager_
;
// Manager for sequence-based linked embeddings.
SequenceLinkManager
sequence_link_manager_
;
// Sequence-based predictor, if |deterministic_| is false.
std
::
unique_ptr
<
SequencePredictor
>
sequence_predictor_
;
};
// Implementation details below.
inline
const
SequenceLinkManager
&
SequenceModel
::
sequence_link_manager
()
const
{
return
sequence_link_manager_
;
}
inline
const
SequenceFeatureManager
&
SequenceModel
::
sequence_feature_manager
()
const
{
return
sequence_feature_manager_
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
int
kNumSteps
=
50
;
constexpr
int
kVocabularySize
=
123
;
constexpr
int
kLinkedDim
=
11
;
constexpr
int
kLogitsDim
=
17
;
constexpr
char
kLogitsName
[]
=
"oddly_named_logits"
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
float
kPreviousLayerValue
=
-
1.0
;
// Sequence extractor that extracts [0, 2, 4, ...].
class
EvenNumbers
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
clear
();
for
(
int
i
=
0
;
i
<
num_steps_
;
++
i
)
ids
->
push_back
(
2
*
i
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
EvenNumbers
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
EvenNumbers
);
// Trivial linker that links each index to the previous one.
class
LinkToPrevious
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
clear
();
for
(
int
i
=
0
;
i
<
num_steps_
;
++
i
)
links
->
push_back
(
i
-
1
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
LinkToPrevious
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkToPrevious
);
// Trivial predictor that captures the prediction logits.
class
CaptureLogits
:
public
SequencePredictor
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
)
const
override
{
GetLogits
()
=
logits
;
return
tensorflow
::
Status
::
OK
();
}
// Returns the captured logits.
static
Matrix
<
float
>
&
GetLogits
()
{
static
auto
*
logits
=
new
Matrix
<
float
>
();
return
*
logits
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
CaptureLogits
);
class
SequenceModelTest
:
public
NetworkTestBase
{
protected:
// Adds default call expectations. Since these are added first, they can be
// overridden by call expectations in individual tests.
SequenceModelTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
// Some tests overwrite these; ensure that they are restored to the normal
// values at the start of each test.
EvenNumbers
::
SetNumSteps
(
kNumSteps
);
LinkToPrevious
::
SetNumSteps
(
kNumSteps
);
CaptureLogits
::
GetLogits
()
=
Matrix
<
float
>
();
}
// Initializes the |model_| and its underlying feature managers from the
// |component_spec|, then uses the |model_| to preprocess and predict the
// |input_|. Also sets each row of the logits to twice its row index. On
// error, returns non-OK.
tensorflow
::
Status
Run
(
ComponentSpec
component_spec
)
{
component_spec
.
set_name
(
kTestComponentName
);
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
AddLayer
(
kLogitsName
,
kLogitsDim
);
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
TF_RETURN_IF_ERROR
(
model_
.
Initialize
(
component_spec
,
kLogitsName
,
&
fixed_embedding_manager_
,
&
linked_embedding_manager_
,
&
network_state_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kPreviousLayerValue
);
StartComponent
(
0
);
TF_RETURN_IF_ERROR
(
model_
.
Preprocess
(
&
session_state_
,
&
compute_session_
,
&
evaluate_state_
));
MutableMatrix
<
float
>
logits
=
GetLayer
(
kTestComponentName
,
kLogitsName
);
for
(
int
row
=
0
;
row
<
logits
.
num_rows
();
++
row
)
{
for
(
int
column
=
0
;
column
<
logits
.
num_columns
();
++
column
)
{
logits
.
row
(
row
)[
column
]
=
2.0
*
row
;
}
}
return
model_
.
Predict
(
network_states_
,
&
evaluate_state_
);
}
// Returns the sequence size passed to the |backend_|.
int
GetBackendSequenceSize
()
{
// The sequence size is not directly exposed, but can be inferred using one
// of the reverse step translators.
return
backend_
.
GetStepLookupFunction
(
"reverse-token"
)(
0
,
0
,
0
)
+
1
;
}
// Fixed and linked embedding managers.
FixedEmbeddingManager
fixed_embedding_manager_
;
LinkedEmbeddingManager
linked_embedding_manager_
;
// Input batch injected into Preprocess() by default.
InputBatchCache
input_
;
// Backend injected into Preprocess().
SequenceBackend
backend_
;
// Sequence-based model.
SequenceModel
model_
;
// Per-evaluation state.
SequenceModel
::
EvaluateState
evaluate_state_
;
};
// Returns a ComponentSpec that is supported.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_num_actions
(
kLogitsDim
);
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"EvenNumbers"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"LinkToPrevious"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"CaptureLogits"
});
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
-
1
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
return
component_spec
;
}
// Tests that the model supports a supported spec.
TEST_F
(
SequenceModelTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
EXPECT_TRUE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with the wrong backend.
TEST_F
(
SequenceModelTest
,
UnsupportedBackend
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with no features.
TEST_F
(
SequenceModelTest
,
UnsupportedNoFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with a multi-embedding fixed feature.
TEST_F
(
SequenceModelTest
,
UnsupportedMultiEmbeddingFixedFeature
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_fixed_feature
(
0
)
->
set_size
(
2
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with a multi-embedding linked feature.
TEST_F
(
SequenceModelTest
,
UnsupportedMultiEmbeddingLinkedFeature
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_size
(
2
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with only recurrent links.
TEST_F
(
SequenceModelTest
,
UnsupportedOnlyRecurrentLinks
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_name
(
"foo"
);
component_spec
.
clear_fixed_feature
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
"foo"
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that Initialize() succeeds on a supported spec.
TEST_F
(
SequenceModelTest
,
InitializeSupported
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_FALSE
(
model_
.
deterministic
());
EXPECT_TRUE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() detects deterministic components.
TEST_F
(
SequenceModelTest
,
InitializeDeterministic
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
1
);
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_TRUE
(
model_
.
deterministic
());
EXPECT_TRUE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() detects right-to-left components.
TEST_F
(
SequenceModelTest
,
InitializeLeftToRight
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_transition_system
()
->
mutable_parameters
()
->
insert
(
{
"left_to_right"
,
"false"
});
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_FALSE
(
model_
.
deterministic
());
EXPECT_FALSE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() fails if the backend is wrong.
TEST_F
(
SequenceModelTest
,
WrongBackend
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Invalid component backend"
));
}
// Tests that Initialize() fails if the number of actions in the ComponentSpec
// does not match the logits.
TEST_F
(
SequenceModelTest
,
WrongNumActions
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
kLogitsDim
+
1
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Logits dimension mismatch"
));
}
// Tests that Initialize() fails if an unknown sequence extractor is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequenceExtractor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Extractor"
));
}
// Tests that Initialize() fails if an unknown sequence linker is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequenceLinker
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Linker"
));
}
// Tests that Initialize() fails if an unknown sequence predictor is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequencePredictor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_predictor"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Predictor"
));
}
// Tests that Initialize() fails on an unknown component builder parameter.
TEST_F
(
SequenceModelTest
,
UnknownComponentBuilderParameter
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"bad"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown attribute"
));
}
// Tests that Initialize() fails if there are no fixed or linked features.
TEST_F
(
SequenceModelTest
,
InitializeRequiresFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
""
;
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
""
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"No fixed or linked features"
));
}
// Tests that the model fails if a null batch is returned.
TEST_F
(
SequenceModelTest
,
NullBatch
)
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
()).
WillOnce
(
Return
(
nullptr
));
EXPECT_THAT
(
Run
(
MakeSupportedSpec
()),
test
::
IsErrorWithSubstr
(
"Null input batch"
));
}
// Tests that the model properly sets up the EvaluateState and logits.
TEST_F
(
SequenceModelTest
,
Success
)
{
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
0
),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
1
),
2
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
2
),
4
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
kNumSteps
);
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
evaluate_state_
.
links
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
0.0
);
EXPECT_TRUE
(
is_out_of_bounds
);
evaluate_state_
.
links
.
Get
(
0
,
1
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
kPreviousLayerValue
);
EXPECT_FALSE
(
is_out_of_bounds
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model works with only fixed features.
TEST_F
(
SequenceModelTest
,
FixedFeaturesOnly
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_linked_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
""
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
0
),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
1
),
2
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
2
),
4
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
0
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model works with only linked features.
TEST_F
(
SequenceModelTest
,
LinkedFeaturesOnly
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
""
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
kNumSteps
);
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
evaluate_state_
.
links
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
0.0
);
EXPECT_TRUE
(
is_out_of_bounds
);
evaluate_state_
.
links
.
Get
(
0
,
1
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
kPreviousLayerValue
);
EXPECT_FALSE
(
is_out_of_bounds
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model fails if the fixed and linked features disagree on the
// number of steps.
TEST_F
(
SequenceModelTest
,
FixedAndLinkedDisagree
)
{
EvenNumbers
::
SetNumSteps
(
5
);
LinkToPrevious
::
SetNumSteps
(
6
);
EXPECT_THAT
(
Run
(
MakeSupportedSpec
()),
test
::
IsErrorWithSubstr
(
"Sequence length mismatch between fixed "
"features (5) and linked features (6)"
));
}
// Tests that the model can handle an empty sequence.
TEST_F
(
SequenceModelTest
,
EmptySequence
)
{
EvenNumbers
::
SetNumSteps
(
0
);
LinkToPrevious
::
SetNumSteps
(
0
);
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
0
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
0
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequencePredictor
::
Select
(
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequencePredictor
>
current_predictor
(
factory_function
());
if
(
!
current_predictor
->
Supports
(
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequencePredictors support ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequencePredictor supports ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequencePredictor
::
New
(
const
string
&
name
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequencePredictor
>
*
predictor
)
{
std
::
unique_ptr
<
SequencePredictor
>
matching_predictor
;
TF_RETURN_IF_ERROR
(
SequencePredictor
::
CreateOrError
(
name
,
&
matching_predictor
));
TF_RETURN_IF_ERROR
(
matching_predictor
->
Initialize
(
component_spec
));
// Success; make modifications.
*
predictor
=
std
::
move
(
matching_predictor
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Predictor"
,
dragnn
::
runtime
::
SequencePredictor
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_predictor.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for making predictions on sequences.
//
// This predictor can be used to avoid ComputeSession overhead in simple cases;
// for example, predicting sequences of POS tags.
class
SequencePredictor
:
public
RegisterableClass
<
SequencePredictor
>
{
public:
// Sets |predictor| to an instance of the subclass named |name| initialized
// from the |component_spec|. On error, returns non-OK and modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequencePredictor
>
*
predictor
);
SequencePredictor
(
const
SequencePredictor
&
)
=
delete
;
SequencePredictor
&
operator
=
(
const
SequencePredictor
&
)
=
delete
;
virtual
~
SequencePredictor
()
=
default
;
// Sets |name| to the registered name of the SequencePredictor that supports
// the |component_spec|. On error, returns non-OK and modifies nothing. The
// returned statuses include:
// * OK: If a supporting SequencePredictor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Makes a sequence of predictions using the per-step |logits| and writes
// annotations to the |input|.
virtual
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
input
)
const
=
0
;
protected:
SequencePredictor
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequencePredictor
>::
Create
;
// Returns true if this supports the |component_spec|. Implementations must
// coordinate to ensure that at most one supports any given |component_spec|.
virtual
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |component_spec|. On error, returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Predictor"
,
dragnn
::
runtime
::
SequencePredictor
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequencePredictor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequencePredictorTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequencePredictor
::
Select
(
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequencePredictor
::
New
(
name
,
component_spec
,
&
predictor
));
EXPECT_NE
(
predictor
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequencePredictorTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequencePredictor
::
Select
(
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequencePredictor
::
New
(
name
,
component_spec
,
&
predictor
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
predictor
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequencePredictorTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequencePredictor
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequencePredictor supports ComponentSpec"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequencePredictorTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequencePredictor
::
New
(
"Unsupported"
,
component_spec
,
&
predictor
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Predictor"
));
EXPECT_EQ
(
predictor
,
nullptr
);
}
// Tests that multiple supporting predictors are reported as INTERNAL errors.
TEST
(
SequencePredictorTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequencePredictor
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequencePredictors support ComponentSpec"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/session_state.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_H_
#define DRAGNN_RUNTIME_SESSION_STATE_H_
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// State associated with a ComputeSession being evaluated by a DRAGNN network,
// reusable across multiple evaluations. Unlike the ComputeSession, which is
// both the input and output of the network, this state is strictly internal to
// the network. Production code should allocate these via a SessionStatePool.
struct
SessionState
{
// The network states that connect the pipeline of components.
NetworkStates
network_states
;
// Generic set of typed extensions.
Extensions
extensions
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_H_
research/syntaxnet/dragnn/runtime/session_state_pool.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <algorithm>
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
SessionStatePool
::
SessionStatePool
(
size_t
max_free_states
)
:
max_free_states_
(
max_free_states
)
{}
std
::
unique_ptr
<
SessionState
>
SessionStatePool
::
Acquire
()
{
{
// Exclude the slow path from the critical region.
tensorflow
::
mutex_lock
lock
(
mutex_
);
if
(
!
free_list_
.
empty
())
{
// Fast path: reuse a free state.
std
::
unique_ptr
<
SessionState
>
state
=
std
::
move
(
free_list_
.
back
());
free_list_
.
pop_back
();
return
state
;
}
}
// Slow path: allocate a new state.
return
std
::
unique_ptr
<
SessionState
>
(
new
SessionState
());
}
void
SessionStatePool
::
Release
(
std
::
unique_ptr
<
SessionState
>
state
)
{
{
// Exclude the slow path from the critical region.
tensorflow
::
mutex_lock
lock
(
mutex_
);
if
(
free_list_
.
size
()
<
max_free_states_
)
{
// Fast path: reclaim in the free list.
free_list_
.
emplace_back
(
std
::
move
(
state
));
return
;
}
}
// Slow path: discard the excess |state| when it goes out of scope.
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/session_state_pool.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#define DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#include <stddef.h>
#include <memory>
#include <utility>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A thread-safe pool of session states that maintains a free list. The free
// list is bounded, so a spike in usage does not permanently increase the size
// of the pool. Use ScopedSessionState to interact with the pool.
class
SessionStatePool
{
public:
// Creates a pool whose free list holds at most |max_free_states| states.
//
// If usage spikes are not a concern (e.g., during offline processing where
// the runtime is called from a fixed-size pool of threads), then specify a
// large value like SIZE_MAX. That eliminates unnecessary deallocations and
// reallocations, and eliminates the need to coordinate the thread pool size
// with this pool's size.
//
// If memory usage dominates CPU usage, then specify 0 to eliminate overhead
// from the free list.
//
// TODO(googleuser): An alternative is to set a target allocation
// rate (e.g., 2% of Acquire()s should create a new state), and let the pool
// adapt its free list size to achieve that rate.
explicit
SessionStatePool
(
size_t
max_free_states
);
private:
friend
class
ScopedSessionState
;
// Returns a state acquired from this pool. The caller is the exclusive user
// of the returned state until it is passed to Release().
std
::
unique_ptr
<
SessionState
>
Acquire
();
// Releases the |state| back to this pool. The |state| must be the result of
// a previous Acquire(). The caller can no longer use the |state|.
void
Release
(
std
::
unique_ptr
<
SessionState
>
state
);
// Maximum number of states to keep in the |free_list_|.
const
size_t
max_free_states_
;
// Mutex guarding the |free_list_|.
tensorflow
::
mutex
mutex_
;
// List of previously-Release()d states.
std
::
vector
<
std
::
unique_ptr
<
SessionState
>>
free_list_
GUARDED_BY
(
mutex_
);
};
// RAII wrapper that manages a session state acquired from a pool. The wrapped
// state is usable during the lifetime of the wrapper.
class
ScopedSessionState
{
public:
// Implements RAII semantics.
explicit
ScopedSessionState
(
SessionStatePool
*
pool
)
:
pool_
(
pool
),
state_
(
pool_
->
Acquire
())
{}
~
ScopedSessionState
()
{
pool_
->
Release
(
std
::
move
(
state_
));
}
// Prevents double-release.
ScopedSessionState
(
const
ScopedSessionState
&
that
)
=
delete
;
ScopedSessionState
&
operator
=
(
const
ScopedSessionState
&
that
)
=
delete
;
// Provides std::unique_ptr-like access.
SessionState
*
get
()
const
{
return
state_
.
get
();
}
SessionState
&
operator
*
()
const
{
return
*
get
();
}
SessionState
*
operator
->
()
const
{
return
get
();
}
private:
// Pool from which the |state_| was acquired.
SessionStatePool
*
const
pool_
;
// Wrapped session state.
std
::
unique_ptr
<
SessionState
>
state_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
research/syntaxnet/dragnn/runtime/session_state_pool_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <stddef.h>
#include <set>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Maximum number of free states.
static
constexpr
size_t
kMaxFreeStates
=
16
;
class
SessionStatePoolTest
:
public
::
testing
::
Test
{
protected:
SessionStatePool
pool_
{
kMaxFreeStates
};
};
// Tests that ScopedSessionState can be used to acquire a valid state.
TEST_F
(
SessionStatePoolTest
,
ScopedWrapper
)
{
const
ScopedSessionState
state
(
&
pool_
);
EXPECT_TRUE
(
state
.
get
());
// non-null
}
// Tests that the active states claimed from the pool are unique.
TEST_F
(
SessionStatePoolTest
,
UniqueActiveStates
)
{
// NB: Don't use std::unique_ptr<ScopedSessionState> in real code. The test
// does this because it's otherwise difficult to acquire lots of states.
std
::
vector
<
std
::
unique_ptr
<
ScopedSessionState
>>
states
;
for
(
size_t
i
=
0
;
i
<
100
;
++
i
)
{
states
.
emplace_back
(
new
ScopedSessionState
(
&
pool_
));
}
// Check that all of the states are unique.
std
::
set
<
const
SessionState
*>
state_ptrs
;
for
(
const
auto
&
state
:
states
)
{
EXPECT_TRUE
(
state_ptrs
.
insert
(
state
->
get
()).
second
);
}
EXPECT_TRUE
(
state_ptrs
.
find
(
nullptr
)
==
state_ptrs
.
end
());
}
// Tests that active states, when released, are reclaimed and reused.
TEST_F
(
SessionStatePoolTest
,
Reuse
)
{
std
::
set
<
const
SessionState
*>
state_ptrs
;
{
// Grab exactly as many states as the free list can hold.
std
::
vector
<
std
::
unique_ptr
<
ScopedSessionState
>>
states
;
for
(
size_t
i
=
0
;
i
<
kMaxFreeStates
;
++
i
)
{
states
.
emplace_back
(
new
ScopedSessionState
(
&
pool_
));
EXPECT_TRUE
(
state_ptrs
.
insert
(
states
.
back
()
->
get
()).
second
);
}
}
{
// Grab the same number of states again and check that they are the same
// objects we saw in the first loop.
std
::
vector
<
std
::
unique_ptr
<
ScopedSessionState
>>
states
;
for
(
size_t
i
=
0
;
i
<
kMaxFreeStates
;
++
i
)
{
states
.
emplace_back
(
new
ScopedSessionState
(
&
pool_
));
EXPECT_FALSE
(
state_ptrs
.
insert
(
states
.
back
()
->
get
()).
second
);
}
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/stateless_component_transformer.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns true if the |component_type| can be transformed by this.
bool
ShouldTransform
(
const
string
&
component_type
)
{
for
(
const
char
*
supported_type
:
{
"SyntaxNetHeadSelectionComponent"
,
//
"SyntaxNetMstSolverComponent"
,
//
})
{
if
(
component_type
==
supported_type
)
return
true
;
}
return
false
;
}
// Changes the backend for some components to StatelessComponent.
class
StatelessComponentTransformer
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
override
{
if
(
ShouldTransform
(
component_type
))
{
component_spec
->
mutable_backend
()
->
set_registered_name
(
"StatelessComponent"
);
}
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
StatelessComponentTransformer
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/stateless_component_transformer_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Arbitrary supported component type.
constexpr
char
kSupportedComponentType
[]
=
"SyntaxNetHeadSelectionComponent"
;
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
kSupportedComponentType
);
return
component_spec
;
}
// Tests that a compatible spec is modified to use StatelessComponent.
TEST
(
StatelessComponentTransformerTest
,
Compatible
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
mutable_backend
()
->
set_registered_name
(
"StatelessComponent"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that other component specs are not modified.
TEST
(
StatelessComponentTransformerTest
,
Incompatible
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"other"
);
const
ComponentSpec
expected_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_extractor.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/unicode_dictionary.h"
#include "syntaxnet/base.h"
#include "syntaxnet/segmenter_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "util/utf8/unicodetext.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Sequence extractor that extracts characters from a SyntaxNetComponent batch.
class
SyntaxNetCharacterSequenceExtractor
:
public
TermMapSequenceExtractor
<
UnicodeDictionary
>
{
public:
SyntaxNetCharacterSequenceExtractor
();
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
override
;
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
override
;
tensorflow
::
Status
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
override
;
private:
// Parses |fml| and sets |min_frequency| and |max_num_terms| to the specified
// values. If the |fml| does not specify a supported feature, returns non-OK
// and modifies nothing.
static
tensorflow
::
Status
ParseFml
(
const
string
&
fml
,
int
*
min_frequency
,
int
*
max_num_terms
);
// Feature IDs for break characters and unknown characters.
int32
break_id_
=
-
1
;
int32
unknown_id_
=
-
1
;
};
SyntaxNetCharacterSequenceExtractor
::
SyntaxNetCharacterSequenceExtractor
()
:
TermMapSequenceExtractor
(
"char-map"
)
{}
tensorflow
::
Status
SyntaxNetCharacterSequenceExtractor
::
ParseFml
(
const
string
&
fml
,
int
*
min_frequency
,
int
*
max_num_terms
)
{
return
ParseTermMapFml
(
fml
,
{
"char-input"
,
"text-char"
},
min_frequency
,
max_num_terms
);
}
bool
SyntaxNetCharacterSequenceExtractor
::
Supports
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
{
TransitionSystemTraits
traits
(
component_spec
);
int
unused_min_frequency
=
0
;
int
unused_max_num_terms
=
0
;
const
tensorflow
::
Status
parse_fml_status
=
ParseFml
(
channel
.
fml
(),
&
unused_min_frequency
,
&
unused_max_num_terms
);
return
TermMapSequenceExtractor
::
SupportsTermMap
(
channel
,
component_spec
)
&&
parse_fml_status
.
ok
()
&&
component_spec
.
backend
().
registered_name
()
==
"SyntaxNetComponent"
&&
traits
.
is_sequential
&&
traits
.
is_character_scale
;
}
tensorflow
::
Status
SyntaxNetCharacterSequenceExtractor
::
Initialize
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
{
int
min_frequency
=
0
;
int
max_num_terms
=
0
;
TF_RETURN_IF_ERROR
(
ParseFml
(
channel
.
fml
(),
&
min_frequency
,
&
max_num_terms
));
TF_RETURN_IF_ERROR
(
TermMapSequenceExtractor
::
InitializeTermMap
(
channel
,
component_spec
,
min_frequency
,
max_num_terms
));
const
int
num_known
=
term_map
().
size
();
break_id_
=
num_known
;
unknown_id_
=
break_id_
+
1
;
const
int
map_vocab_size
=
unknown_id_
+
1
;
const
int
spec_vocab_size
=
channel
.
vocabulary_size
();
if
(
map_vocab_size
!=
spec_vocab_size
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Character vocabulary size mismatch between term map ("
,
map_vocab_size
,
") and ComponentSpec ("
,
spec_vocab_size
,
")"
);
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SyntaxNetCharacterSequenceExtractor
::
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
{
ids
->
clear
();
const
std
::
vector
<
SyntaxNetSentence
>
&
data
=
*
input
->
GetAs
<
SentenceInputBatch
>
()
->
data
();
if
(
data
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-singleton batch: got "
,
data
.
size
(),
" elements"
);
}
const
Sentence
&
sentence
=
*
data
[
0
].
sentence
();
if
(
sentence
.
token_size
()
==
0
)
return
tensorflow
::
Status
::
OK
();
const
string
&
text
=
sentence
.
text
();
const
int
start_byte
=
sentence
.
token
(
0
).
start
();
const
int
end_byte
=
sentence
.
token
(
sentence
.
token_size
()
-
1
).
end
();
const
int
num_bytes
=
end_byte
-
start_byte
+
1
;
string
character
;
UnicodeText
unicode_text
;
unicode_text
.
PointToUTF8
(
text
.
data
()
+
start_byte
,
num_bytes
);
const
auto
end
=
unicode_text
.
end
();
for
(
auto
it
=
unicode_text
.
begin
();
it
!=
end
;
++
it
)
{
character
.
assign
(
it
.
utf8_data
(),
it
.
utf8_length
());
if
(
SegmenterUtils
::
IsBreakChar
(
character
))
{
ids
->
push_back
(
break_id_
);
}
else
{
ids
->
push_back
(
term_map
().
Lookup
(
character
.
data
(),
character
.
size
(),
unknown_id_
));
}
}
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
SyntaxNetCharacterSequenceExtractor
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_extractor_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
constexpr
char
kResourceName
[]
=
"char-map"
;
// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec
MakeSpec
(
const
string
&
text
,
const
string
&
path
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
text
,
&
component_spec
));
AddTermMapResource
(
kResourceName
,
path
,
&
component_spec
);
return
component_spec
;
}
// Returns a supported ComponentSpec that points at the term map in the |path|.
ComponentSpec
MakeSupportedSpec
(
const
string
&
path
=
"/dev/null"
)
{
return
MakeSpec
(
R"(transition_system { registered_name: 'char-shift-only' }
backend { registered_name: 'SyntaxNetComponent' }
fixed_feature {} # breaks hard-coded refs to channel 0
fixed_feature { size: 1 fml: 'char-input.text-char' })"
,
path
);
}
// Returns a default sentence.
Sentence
MakeSentence
()
{
Sentence
sentence
;
sentence
.
set_text
(
"a bc def"
);
Token
*
token
=
sentence
.
add_token
();
token
->
set_start
(
0
);
token
->
set_end
(
sentence
.
text
().
size
()
-
1
);
token
->
set_word
(
sentence
.
text
());
return
sentence
;
}
// Tests that the extractor supports an appropriate spec.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
FixedFeatureChannel
&
channel
=
component_spec
.
fixed_feature
(
1
);
string
name
;
TF_ASSERT_OK
(
SequenceExtractor
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetCharacterSequenceExtractor"
);
}
// Tests that the extractor requires the proper backend.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
WrongBackend
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
const
FixedFeatureChannel
&
channel
=
component_spec
.
fixed_feature
(
1
);
string
name
;
EXPECT_THAT
(
SequenceExtractor
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceExtractor supports channel"
));
}
// Tests that the extractor requires the proper transition system.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
WrongTransitionSystem
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"bad"
);
const
FixedFeatureChannel
&
channel
=
component_spec
.
fixed_feature
(
1
);
string
name
;
EXPECT_THAT
(
SequenceExtractor
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceExtractor supports channel"
));
}
// Tests that the extractor requires the proper FML.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
WrongFml
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_fixed_feature
(
1
)
->
set_fml
(
"bad"
);
const
FixedFeatureChannel
&
channel
=
component_spec
.
fixed_feature
(
1
);
string
name
;
EXPECT_THAT
(
SequenceExtractor
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceExtractor supports channel"
));
}
// Tests that the extractor can be initialized and used to extract feature IDs.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
InitializeAndGetIds
)
{
// Terms are sorted by descending frequency, so this ensures a=0, b=1, etc.
const
string
path
=
WriteTermMap
({{
"a"
,
5
},
{
"b"
,
4
},
{
"c"
,
3
},
{
"d"
,
2
},
{
"e"
,
1
}});
ComponentSpec
component_spec
=
MakeSupportedSpec
(
path
);
FixedFeatureChannel
&
channel
=
*
component_spec
.
mutable_fixed_feature
(
1
);
channel
.
set_vocabulary_size
(
7
);
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
TF_ASSERT_OK
(
SequenceExtractor
::
New
(
"SyntaxNetCharacterSequenceExtractor"
,
channel
,
component_spec
,
&
extractor
));
const
Sentence
sentence
=
MakeSentence
();
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
ids
;
TF_ASSERT_OK
(
extractor
->
GetIds
(
&
input
,
&
ids
));
// 0-4 = 'a' to 'e'
// 5 = break chars (whitespace)
// 6 = unknown chars (e.g., 'f')
const
std
::
vector
<
int32
>
expected_ids
=
{
0
,
5
,
1
,
2
,
5
,
3
,
4
,
6
};
EXPECT_EQ
(
ids
,
expected_ids
);
}
// Tests that an empty term map works.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
EmptyTermMap
)
{
const
string
path
=
WriteTermMap
({});
ComponentSpec
component_spec
=
MakeSupportedSpec
(
path
);
FixedFeatureChannel
&
channel
=
*
component_spec
.
mutable_fixed_feature
(
1
);
channel
.
set_vocabulary_size
(
2
);
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
TF_ASSERT_OK
(
SequenceExtractor
::
New
(
"SyntaxNetCharacterSequenceExtractor"
,
channel
,
component_spec
,
&
extractor
));
const
Sentence
sentence
=
MakeSentence
();
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
ids
=
{
1
,
2
,
3
,
4
};
// should be overwritten
TF_ASSERT_OK
(
extractor
->
GetIds
(
&
input
,
&
ids
));
const
std
::
vector
<
int32
>
expected_ids
=
{
1
,
0
,
1
,
1
,
0
,
1
,
1
,
1
};
EXPECT_EQ
(
ids
,
expected_ids
);
}
// Tests that GetIds() fails if the batch is the wrong size.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
WrongBatchSize
)
{
const
string
path
=
WriteTermMap
({});
ComponentSpec
component_spec
=
MakeSupportedSpec
(
path
);
FixedFeatureChannel
&
channel
=
*
component_spec
.
mutable_fixed_feature
(
1
);
channel
.
set_vocabulary_size
(
2
);
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
TF_ASSERT_OK
(
SequenceExtractor
::
New
(
"SyntaxNetCharacterSequenceExtractor"
,
channel
,
component_spec
,
&
extractor
));
const
Sentence
sentence
=
MakeSentence
();
const
std
::
vector
<
string
>
data
=
{
sentence
.
SerializeAsString
(),
sentence
.
SerializeAsString
()};
InputBatchCache
input
(
data
);
std
::
vector
<
int32
>
ids
;
EXPECT_THAT
(
extractor
->
GetIds
(
&
input
,
&
ids
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 2 elements"
));
}
// Tests that initialization fails if the vocabulary size does not match.
TEST
(
SyntaxNetCharacterSequenceExtractorTest
,
WrongVocabularySize
)
{
const
string
path
=
WriteTermMap
({});
ComponentSpec
component_spec
=
MakeSupportedSpec
(
path
);
FixedFeatureChannel
&
channel
=
*
component_spec
.
mutable_fixed_feature
(
1
);
channel
.
set_vocabulary_size
(
1000
);
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
EXPECT_THAT
(
SequenceExtractor
::
New
(
"SyntaxNetCharacterSequenceExtractor"
,
channel
,
component_spec
,
&
extractor
),
test
::
IsErrorWithSubstr
(
"Character vocabulary size mismatch between term "
"map (2) and ComponentSpec (1000)"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_linkers.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Focus character to link to in each token.
enum
class
Focus
{
kFirst
,
// first character in token
kLast
,
// last character in token
};
// Translator to apply to the linked character index.
enum
class
Translator
{
kIdentity
,
// direct identity link
kReversed
,
// reverse-order link
};
// Returns the LinkedFeatureChannel.fml for the |focus|.
string
ChannelFml
(
Focus
focus
)
{
switch
(
focus
)
{
case
Focus
::
kFirst
:
return
"input.first-char-focus"
;
case
Focus
::
kLast
:
return
"input.last-char-focus"
;
}
}
// Returns the LinkedFeatureChannel.source_translator for the |translator|.
string
ChannelTranslator
(
Translator
translator
)
{
switch
(
translator
)
{
case
Translator
::
kIdentity
:
return
"identity"
;
case
Translator
::
kReversed
:
return
"reverse-char"
;
}
}
// Returns the |focus| byte index for the |token|. The returned index must be
// within the span of the |token|.
int32
GetFocusByte
(
Focus
focus
,
const
Token
&
token
)
{
switch
(
focus
)
{
case
Focus
::
kFirst
:
return
token
.
start
();
case
Focus
::
kLast
:
return
token
.
end
();
}
}
// Applies the |translator| to the character |index| w.r.t. the |last_index| and
// returns the result.
int32
Translate
(
Translator
translator
,
int32
last_index
,
int32
index
)
{
switch
(
translator
)
{
case
Translator
::
kIdentity
:
return
index
;
case
Translator
::
kReversed
:
return
last_index
-
index
;
}
}
// Translates links from tokens in the target layer to UTF-8 characters in the
// source layer. Templated on a |focus| and |translator| (see above).
template
<
Focus
focus
,
Translator
translator
>
class
SyntaxNetCharacterSequenceLinker
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
override
;
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
override
;
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
override
;
};
template
<
Focus
focus
,
Translator
translator
>
bool
SyntaxNetCharacterSequenceLinker
<
focus
,
translator
>::
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
{
TransitionSystemTraits
traits
(
component_spec
);
return
channel
.
fml
()
==
ChannelFml
(
focus
)
&&
channel
.
source_translator
()
==
ChannelTranslator
(
translator
)
&&
component_spec
.
backend
().
registered_name
()
==
"SyntaxNetComponent"
&&
traits
.
is_sequential
&&
traits
.
is_token_scale
;
}
template
<
Focus
focus
,
Translator
translator
>
tensorflow
::
Status
SyntaxNetCharacterSequenceLinker
<
focus
,
translator
>::
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
{
return
tensorflow
::
Status
::
OK
();
}
template
<
Focus
focus
,
Translator
translator
>
tensorflow
::
Status
SyntaxNetCharacterSequenceLinker
<
focus
,
translator
>::
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
{
const
std
::
vector
<
SyntaxNetSentence
>
&
batch
=
*
input
->
GetAs
<
SentenceInputBatch
>
()
->
data
();
if
(
batch
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-singleton batch: got "
,
batch
.
size
(),
" elements"
);
}
const
Sentence
&
sentence
=
*
batch
[
0
].
sentence
();
const
int32
num_tokens
=
sentence
.
token_size
();
links
->
resize
(
num_tokens
);
if
(
num_tokens
==
0
)
return
tensorflow
::
Status
::
OK
();
// Given the properties selected in Supports(), the number of source steps
// must match the number of UTF-8 characters. The last character index will
// be used in Translate().
const
int32
last_char_index
=
static_cast
<
int32
>
(
source_num_steps
)
-
1
;
// [start,end) byte range of the text spanned by the sentence tokens.
const
int32
start_byte
=
sentence
.
token
(
0
).
start
();
const
int32
end_byte
=
sentence
.
token
(
num_tokens
-
1
).
end
()
+
1
;
const
char
*
const
data
=
sentence
.
text
().
data
();
if
(
UniLib
::
IsTrailByte
(
data
[
start_byte
]))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"First token starts in the middle of a UTF-8 character: "
,
sentence
.
token
(
0
).
ShortDebugString
());
}
// Current character index and its past-the-end byte in the sentence.
int32
char_index
=
0
;
int32
char_end_byte
=
start_byte
+
UniLib
::
OneCharLen
(
data
+
start_byte
);
// Current token index and its byte index.
int32
token_index
=
0
;
int32
token_byte
=
GetFocusByte
(
focus
,
sentence
.
token
(
0
));
// Scan through the characters and tokens. For each token, we assign it the
// character whose byte range contains its focus byte.
while
(
true
)
{
// If the character ends after the token, then the token must lie within the
// character, or we would have consumed the token in a previous iteration.
if
(
char_end_byte
>
token_byte
)
{
(
*
links
)[
token_index
]
=
Translate
(
translator
,
last_char_index
,
char_index
);
if
(
++
token_index
>=
num_tokens
)
break
;
token_byte
=
GetFocusByte
(
focus
,
sentence
.
token
(
token_index
));
}
else
if
(
char_end_byte
<
end_byte
)
{
++
char_index
;
char_end_byte
+=
UniLib
::
OneCharLen
(
data
+
char_end_byte
);
}
else
{
break
;
}
}
if
(
char_end_byte
>
end_byte
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Last token ends in the middle of a UTF-8 character: "
,
sentence
.
token
(
num_tokens
-
1
).
ShortDebugString
());
}
// Since GetFocusByte() always returns a byte index within the span of the
// token, the loop above must consume all tokens.
DCHECK_EQ
(
token_index
,
num_tokens
);
return
tensorflow
::
Status
::
OK
();
}
using
SyntaxNetFirstCharacterIdentitySequenceLinker
=
SyntaxNetCharacterSequenceLinker
<
Focus
::
kFirst
,
Translator
::
kIdentity
>
;
using
SyntaxNetFirstCharacterReversedSequenceLinker
=
SyntaxNetCharacterSequenceLinker
<
Focus
::
kFirst
,
Translator
::
kReversed
>
;
using
SyntaxNetLastCharacterIdentitySequenceLinker
=
SyntaxNetCharacterSequenceLinker
<
Focus
::
kLast
,
Translator
::
kIdentity
>
;
using
SyntaxNetLastCharacterReversedSequenceLinker
=
SyntaxNetCharacterSequenceLinker
<
Focus
::
kLast
,
Translator
::
kReversed
>
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SyntaxNetFirstCharacterIdentitySequenceLinker
);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SyntaxNetFirstCharacterReversedSequenceLinker
);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SyntaxNetLastCharacterIdentitySequenceLinker
);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SyntaxNetLastCharacterReversedSequenceLinker
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_character_sequence_linkers_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
ElementsAre
;
// Returns a ComponentSpec parsed from the |text|.
ComponentSpec
ParseSpec
(
const
string
&
text
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
text
,
&
component_spec
));
return
component_spec
;
}
// Returns a ComponentSpec that some linker supports.
ComponentSpec
MakeSupportedSpec
()
{
return
ParseSpec
(
R"(
transition_system { registered_name:'shift-only' }
backend { registered_name:'SyntaxNetComponent' }
linked_feature { fml:'input.first-char-focus' source_translator:'identity' }
)"
);
}
// Returns a Sentence parsed from the |text|.
Sentence
ParseSentence
(
const
string
&
text
)
{
Sentence
sentence
;
CHECK
(
TextFormat
::
ParseFromString
(
text
,
&
sentence
));
return
sentence
;
}
// Returns a default sentence.
Sentence
MakeSentence
()
{
return
ParseSentence
(
R"(
text:'012345678901234567890123456789人1工神2经网¢络'
token { start:30 end:36 word:'人1工' }
token { start:37 end:43 word:'神2经' }
token { start:44 end:51 word:'网¢络' }
)"
);
}
// Number of UTF-8 characters in the default sentence.
constexpr
int
kNumChars
=
9
;
// Tests that the linker supports appropriate specs.
TEST
(
SyntaxNetCharacterSequenceLinkersTest
,
Supported
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetFirstCharacterIdentitySequenceLinker"
);
channel
.
set_source_translator
(
"reverse-char"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetFirstCharacterReversedSequenceLinker"
);
channel
.
set_fml
(
"input.last-char-focus"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetLastCharacterReversedSequenceLinker"
);
channel
.
set_source_translator
(
"identity"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetLastCharacterIdentitySequenceLinker"
);
}
// Tests that the linker requires the right transition system.
TEST
(
SyntaxNetCharacterSequenceLinkersTest
,
WrongTransitionSystem
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right FML.
TEST
(
SyntaxNetCharacterSequenceLinkersTest
,
WrongFml
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_fml
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right translator.
TEST
(
SyntaxNetCharacterSequenceLinkersTest
,
WrongTranslator
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
LinkedFeatureChannel
&
channel
=
*
component_spec
.
mutable_linked_feature
(
0
);
channel
.
set_source_translator
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Tests that the linker requires the right backend.
TEST
(
SyntaxNetCharacterSequenceLinkersTest
,
WrongBackend
)
{
string
name
;
ComponentSpec
component_spec
=
MakeSupportedSpec
();
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
0
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
SequenceLinker
::
Select
(
channel
,
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"No SequenceLinker supports channel"
));
}
// Rig for testing GetLinks().
class
SyntaxNetCharacterSequenceLinkersGetLinksTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
// Initialize() doesn't look at the channel or spec, so use empty protos.
const
ComponentSpec
component_spec
;
const
LinkedFeatureChannel
channel
;
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"SyntaxNetFirstCharacterIdentitySequenceLinker"
,
channel
,
component_spec
,
&
first_identity_
));
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"SyntaxNetFirstCharacterReversedSequenceLinker"
,
channel
,
component_spec
,
&
first_reversed_
));
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"SyntaxNetLastCharacterIdentitySequenceLinker"
,
channel
,
component_spec
,
&
last_identity_
));
TF_ASSERT_OK
(
SequenceLinker
::
New
(
"SyntaxNetLastCharacterReversedSequenceLinker"
,
channel
,
component_spec
,
&
last_reversed_
));
}
// Linkers in all four configurations.
std
::
unique_ptr
<
SequenceLinker
>
first_identity_
;
std
::
unique_ptr
<
SequenceLinker
>
first_reversed_
;
std
::
unique_ptr
<
SequenceLinker
>
last_identity_
;
std
::
unique_ptr
<
SequenceLinker
>
last_reversed_
;
};
// Tests that the linkers can extract links from the default sentence.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
DefaultSentence
)
{
const
Sentence
sentence
=
MakeSentence
();
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
=
{
123
,
456
,
789
};
// gets overwritten
TF_ASSERT_OK
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
0
,
3
,
6
));
TF_ASSERT_OK
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
8
,
5
,
2
));
TF_ASSERT_OK
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
2
,
5
,
8
));
TF_ASSERT_OK
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
6
,
3
,
0
));
}
// Tests that the linkers can handle an empty sentence.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
EmptySentence
)
{
const
Sentence
sentence
;
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
;
TF_ASSERT_OK
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
TF_ASSERT_OK
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
TF_ASSERT_OK
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
TF_ASSERT_OK
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
}
// Tests that the linkers fail if the batch is not a singleton.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
NonSingleton
)
{
const
Sentence
sentence
=
MakeSentence
();
const
std
::
vector
<
string
>
data
=
{
sentence
.
SerializeAsString
(),
sentence
.
SerializeAsString
()};
InputBatchCache
input
(
data
);
std
::
vector
<
int32
>
links
;
EXPECT_THAT
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 2 elements"
));
EXPECT_THAT
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 2 elements"
));
EXPECT_THAT
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 2 elements"
));
EXPECT_THAT
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 2 elements"
));
}
// Tests that the linkers fail if the first token starts in the middle of a
// UTF-8 character.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
FirstTokenStartsWrong
)
{
Sentence
sentence
=
MakeSentence
();
sentence
.
mutable_token
(
0
)
->
set_start
(
31
);
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
;
EXPECT_THAT
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"First token starts in the middle of a UTF-8 character"
));
EXPECT_THAT
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"First token starts in the middle of a UTF-8 character"
));
EXPECT_THAT
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"First token starts in the middle of a UTF-8 character"
));
EXPECT_THAT
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"First token starts in the middle of a UTF-8 character"
));
}
// Tests that the linkers fail if the last token ends in the middle of a UTF-8
// character.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
LastTokenEndsWrong
)
{
Sentence
sentence
=
MakeSentence
();
sentence
.
mutable_token
(
2
)
->
set_end
(
45
);
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
;
EXPECT_THAT
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Last token ends in the middle of a UTF-8 character"
));
EXPECT_THAT
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Last token ends in the middle of a UTF-8 character"
));
EXPECT_THAT
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Last token ends in the middle of a UTF-8 character"
));
EXPECT_THAT
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
),
test
::
IsErrorWithSubstr
(
"Last token ends in the middle of a UTF-8 character"
));
}
// Tests that the linkers can tolerate a sentence where the interior token byte
// offsets are wrong.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
InteriorTokenBoundariesSlightlyWrong
)
{
Sentence
sentence
=
MakeSentence
();
sentence
.
mutable_token
(
0
)
->
set_end
(
35
);
sentence
.
mutable_token
(
1
)
->
set_start
(
38
);
sentence
.
mutable_token
(
1
)
->
set_end
(
42
);
sentence
.
mutable_token
(
2
)
->
set_start
(
45
);
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
;
// The results should be the same as in the default sentence.
TF_ASSERT_OK
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
0
,
3
,
6
));
TF_ASSERT_OK
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
8
,
5
,
2
));
TF_ASSERT_OK
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
2
,
5
,
8
));
TF_ASSERT_OK
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
6
,
3
,
0
));
}
// As above, but places the token boundaries even further off.
TEST_F
(
SyntaxNetCharacterSequenceLinkersGetLinksTest
,
InteriorTokenBoundariesMostlyWrong
)
{
Sentence
sentence
=
MakeSentence
();
sentence
.
mutable_token
(
0
)
->
set_end
(
34
);
sentence
.
mutable_token
(
1
)
->
set_start
(
39
);
sentence
.
mutable_token
(
1
)
->
set_end
(
41
);
sentence
.
mutable_token
(
2
)
->
set_start
(
46
);
InputBatchCache
input
(
sentence
.
SerializeAsString
());
std
::
vector
<
int32
>
links
;
// The results should be the same as in the default sentence.
TF_ASSERT_OK
(
first_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
0
,
3
,
6
));
TF_ASSERT_OK
(
first_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
8
,
5
,
2
));
TF_ASSERT_OK
(
last_identity_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
2
,
5
,
8
));
TF_ASSERT_OK
(
last_reversed_
->
GetLinks
(
kNumChars
,
&
input
,
&
links
));
EXPECT_THAT
(
links
,
ElementsAre
(
6
,
3
,
0
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_head_selection_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/head_selection_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Selects heads for SyntaxNetComponent batches.
class
SyntaxNetHeadSelectionComponent
:
public
HeadSelectionComponentBase
{
public:
SyntaxNetHeadSelectionComponent
()
:
HeadSelectionComponentBase
(
"SyntaxNetHeadSelectionComponent"
,
"SyntaxNetComponent"
)
{}
// Implements Component.
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
};
tensorflow
::
Status
SyntaxNetHeadSelectionComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
InputBatchCache
*
input
=
compute_session
->
GetInputBatchCache
();
if
(
input
==
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Null input batch"
);
}
const
std
::
vector
<
SyntaxNetSentence
>
&
data
=
*
input
->
GetAs
<
SentenceInputBatch
>
()
->
data
();
if
(
data
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-singleton batch: got "
,
data
.
size
(),
" elements"
);
}
const
std
::
vector
<
int
>
&
heads
=
ComputeHeads
(
session_state
);
Sentence
*
sentence
=
data
[
0
].
sentence
();
if
(
heads
.
size
()
!=
sentence
->
token_size
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Sentence size mismatch: expected "
,
heads
.
size
(),
" tokens but got "
,
sentence
->
token_size
());
}
int
token_index
=
0
;
for
(
const
int
head
:
heads
)
{
Token
*
token
=
sentence
->
mutable_token
(
token_index
++
);
if
(
head
==
-
1
)
{
token
->
clear_head
();
}
else
{
token
->
set_head
(
head
);
}
}
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SyntaxNetHeadSelectionComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_head_selection_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kAdjacencyLayerName
[]
=
"adjacency_layer"
;
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec
MakeGoodSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SyntaxNetHeadSelectionComponent"
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SyntaxNetComponent"
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"heads"
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"IdentityNetwork"
);
LinkedFeatureChannel
*
link
=
component_spec
.
add_linked_feature
();
link
->
set_source_component
(
kPreviousComponentName
);
link
->
set_source_layer
(
kAdjacencyLayerName
);
return
component_spec
;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence
MakeSentence
(
int
num_tokens
)
{
Sentence
sentence
;
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
Token
*
token
=
sentence
.
add_token
();
token
->
set_start
(
0
);
// never used; set because required field
token
->
set_end
(
0
);
// never used; set because required field
token
->
set_word
(
"foo"
);
// never used; set because required field
token
->
set_head
(
i
);
}
return
sentence
;
}
class
SyntaxNetHeadSelectionComponentTest
:
public
NetworkTestBase
{
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
adjacency
,
Sentence
*
sentence
)
{
AddComponent
(
kPreviousComponentName
);
AddPairwiseLayer
(
kAdjacencyLayerName
,
1
);
std
::
unique_ptr
<
Component
>
component
;
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"SyntaxNetHeadSelectionComponent"
,
&
component
));
TF_RETURN_IF_ERROR
(
component
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
const
int
num_steps
=
adjacency
.
size
();
StartComponent
(
num_steps
);
MutableMatrix
<
float
>
adjacency_layer
=
GetPairwiseLayer
(
kPreviousComponentName
,
kAdjacencyLayerName
);
for
(
size_t
target
=
0
;
target
<
num_steps
;
++
target
)
{
for
(
size_t
source
=
0
;
source
<
num_steps
;
++
source
)
{
adjacency_layer
.
row
(
target
)[
source
]
=
adjacency
[
target
][
source
];
}
}
string
data
;
CHECK
(
sentence
->
SerializeToString
(
&
data
));
InputBatchCache
input
(
data
);
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input
));
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
CHECK
(
sentence
->
ParseFromString
(
input
.
SerializedData
()[
0
]));
return
tensorflow
::
Status
::
OK
();
}
};
// Tests the head selector on a single-token input.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
ParseOneToken
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
}};
Sentence
sentence
=
MakeSentence
(
1
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_FALSE
(
sentence
.
token
(
0
).
has_head
());
}
// Tests the head selector on a two-token input.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
ParseTwoTokens
)
{
// This adjacency matrix forms a cycle, not a tree, but it doesn't matter
// since the head selector is unstructured.
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
,
1.0
},
//
{
1.0
,
0.0
}};
Sentence
sentence
=
MakeSentence
(
2
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_EQ
(
sentence
.
token
(
0
).
head
(),
1
);
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
0
);
}
// Tests the head selector on a three-token input.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
ParseThreeTokens
)
{
// This adjacency matrix forms a left-headed chain.
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
1.0
,
0.0
,
0.0
},
//
{
1.0
,
0.0
,
0.0
},
//
{
0.0
,
1.0
,
0.0
}};
Sentence
sentence
=
MakeSentence
(
3
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_FALSE
(
sentence
.
token
(
0
).
has_head
());
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
0
);
EXPECT_EQ
(
sentence
.
token
(
2
).
head
(),
1
);
}
// Tests the head selector on a four-token input.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
ParseFourTokens
)
{
// This adjacency matrix forms a right-headed chain.
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
,
1.0
,
0.0
,
0.0
},
//
{
0.0
,
0.0
,
1.0
,
0.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
}};
Sentence
sentence
=
MakeSentence
(
4
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_EQ
(
sentence
.
token
(
0
).
head
(),
1
);
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
2
);
EXPECT_EQ
(
sentence
.
token
(
2
).
head
(),
3
);
EXPECT_FALSE
(
sentence
.
token
(
3
).
has_head
());
}
// Tests that the component supports the good spec.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
MakeGoodSpec
();
string
name
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetHeadSelectionComponent"
);
}
// Tests that the component requires the proper backend.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
WrongComponentBuilder
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
string
name
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that the component requires the proper backend.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
WrongBackend
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
string
name
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
NullBatch
)
{
std
::
unique_ptr
<
Component
>
component
;
TF_ASSERT_OK
(
Component
::
CreateOrError
(
"SyntaxNetHeadSelectionComponent"
,
&
component
));
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
nullptr
));
EXPECT_THAT
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
),
test
::
IsErrorWithSubstr
(
"Null input batch"
));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
WrongBatchSize
)
{
std
::
unique_ptr
<
Component
>
component
;
TF_ASSERT_OK
(
Component
::
CreateOrError
(
"SyntaxNetHeadSelectionComponent"
,
&
component
));
InputBatchCache
input
({
MakeSentence
(
1
).
SerializeAsString
(),
MakeSentence
(
2
).
SerializeAsString
(),
MakeSentence
(
3
).
SerializeAsString
(),
MakeSentence
(
4
).
SerializeAsString
()});
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input
));
EXPECT_THAT
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 4 elements"
));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F
(
SyntaxNetHeadSelectionComponentTest
,
WrongNumTokens
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
1.0
,
0.0
,
0.0
,
0.0
},
//
{
0.0
,
1.0
,
0.0
,
0.0
},
//
{
0.0
,
0.0
,
1.0
,
0.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
}};
// 4-token adjacency matrix with 3-token sentence.
Sentence
sentence
=
MakeSentence
(
3
);
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
),
test
::
IsErrorWithSubstr
(
"Sentence size mismatch: expected 4 tokens but got 3"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_mst_solver_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/mst_solver_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Selects heads for SyntaxNetComponent batches.
class
SyntaxNetMstSolverComponent
:
public
MstSolverComponentBase
{
public:
SyntaxNetMstSolverComponent
()
:
MstSolverComponentBase
(
"SyntaxNetMstSolverComponent"
,
"SyntaxNetComponent"
)
{}
// Implements Component.
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
};
tensorflow
::
Status
SyntaxNetMstSolverComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
InputBatchCache
*
input
=
compute_session
->
GetInputBatchCache
();
if
(
input
==
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Null input batch"
);
}
const
std
::
vector
<
SyntaxNetSentence
>
&
data
=
*
input
->
GetAs
<
SentenceInputBatch
>
()
->
data
();
if
(
data
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-singleton batch: got "
,
data
.
size
(),
" elements"
);
}
tensorflow
::
gtl
::
ArraySlice
<
Index
>
heads
;
TF_RETURN_IF_ERROR
(
ComputeHeads
(
session_state
,
&
heads
));
Sentence
*
sentence
=
data
[
0
].
sentence
();
if
(
heads
.
size
()
!=
sentence
->
token_size
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Sentence size mismatch: expected "
,
heads
.
size
(),
" tokens but got "
,
sentence
->
token_size
());
}
const
int
num_tokens
=
heads
.
size
();
for
(
int
modifier
=
0
;
modifier
<
num_tokens
;
++
modifier
)
{
Token
*
token
=
sentence
->
mutable_token
(
modifier
);
const
int
head
=
heads
[
modifier
];
if
(
head
==
modifier
)
{
token
->
clear_head
();
}
else
{
token
->
set_head
(
head
);
}
}
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SyntaxNetMstSolverComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_mst_solver_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kAdjacencyLayerName
[]
=
"adjacency_layer"
;
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec
MakeGoodSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SyntaxNetMstSolverComponent"
);
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SyntaxNetComponent"
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
"heads"
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
"some.path.to.MstSolverNetwork"
);
LinkedFeatureChannel
*
link
=
component_spec
.
add_linked_feature
();
link
->
set_source_component
(
kPreviousComponentName
);
link
->
set_source_layer
(
kAdjacencyLayerName
);
return
component_spec
;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence
MakeSentence
(
int
num_tokens
)
{
Sentence
sentence
;
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
Token
*
token
=
sentence
.
add_token
();
token
->
set_start
(
0
);
// never used; set because required field
token
->
set_end
(
0
);
// never used; set because required field
token
->
set_word
(
"foo"
);
// never used; set because required field
token
->
set_head
(
i
);
}
return
sentence
;
}
class
SyntaxNetMstSolverComponentTest
:
public
NetworkTestBase
{
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow
::
Status
Run
(
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
std
::
vector
<
float
>>
&
adjacency
,
Sentence
*
sentence
)
{
AddComponent
(
kPreviousComponentName
);
AddPairwiseLayer
(
kAdjacencyLayerName
,
1
);
std
::
unique_ptr
<
Component
>
component
;
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"SyntaxNetMstSolverComponent"
,
&
component
));
TF_RETURN_IF_ERROR
(
component
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
const
int
num_steps
=
adjacency
.
size
();
StartComponent
(
num_steps
);
MutableMatrix
<
float
>
adjacency_layer
=
GetPairwiseLayer
(
kPreviousComponentName
,
kAdjacencyLayerName
);
for
(
size_t
target
=
0
;
target
<
num_steps
;
++
target
)
{
for
(
size_t
source
=
0
;
source
<
num_steps
;
++
source
)
{
adjacency_layer
.
row
(
target
)[
source
]
=
adjacency
[
target
][
source
];
}
}
string
data
;
CHECK
(
sentence
->
SerializeToString
(
&
data
));
InputBatchCache
input
(
data
);
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input
));
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
CHECK
(
sentence
->
ParseFromString
(
input
.
SerializedData
()[
0
]));
return
tensorflow
::
Status
::
OK
();
}
};
// Tests the head selector on a single-token input.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
ParseOneToken
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
}};
Sentence
sentence
=
MakeSentence
(
1
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_FALSE
(
sentence
.
token
(
0
).
has_head
());
}
// Tests the head selector on a two-token input.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
ParseTwoTokens
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
,
1.0
},
//
{
0.9
,
1.0
}};
Sentence
sentence
=
MakeSentence
(
2
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_EQ
(
sentence
.
token
(
0
).
head
(),
1
);
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
-
1
);
}
// Tests the head selector on a three-token input.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
ParseThreeTokens
)
{
// This adjacency matrix forms a left-headed chain.
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
1.0
,
0.0
,
0.0
},
//
{
1.0
,
0.0
,
0.0
},
//
{
0.0
,
1.0
,
0.0
}};
Sentence
sentence
=
MakeSentence
(
3
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_FALSE
(
sentence
.
token
(
0
).
has_head
());
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
0
);
EXPECT_EQ
(
sentence
.
token
(
2
).
head
(),
1
);
}
// Tests the head selector on a four-token input.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
ParseFourTokens
)
{
// This adjacency matrix forms a right-headed chain.
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
0.0
,
1.0
,
0.0
,
0.0
},
//
{
0.0
,
0.0
,
1.0
,
0.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
}};
Sentence
sentence
=
MakeSentence
(
4
);
TF_ASSERT_OK
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
));
EXPECT_EQ
(
sentence
.
token
(
0
).
head
(),
1
);
EXPECT_EQ
(
sentence
.
token
(
1
).
head
(),
2
);
EXPECT_EQ
(
sentence
.
token
(
2
).
head
(),
3
);
EXPECT_FALSE
(
sentence
.
token
(
3
).
has_head
());
}
// Tests that the component supports the good spec.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
MakeGoodSpec
();
string
name
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"SyntaxNetMstSolverComponent"
);
}
// Tests that the component requires the proper backend.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
WrongComponentBuilder
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
string
name
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that the component requires the proper backend.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
WrongBackend
)
{
ComponentSpec
component_spec
=
MakeGoodSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
string
name
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithSubstr
(
"Could not find a best spec for component"
));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
NullBatch
)
{
std
::
unique_ptr
<
Component
>
component
;
TF_ASSERT_OK
(
Component
::
CreateOrError
(
"SyntaxNetMstSolverComponent"
,
&
component
));
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
nullptr
));
EXPECT_THAT
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
),
test
::
IsErrorWithSubstr
(
"Null input batch"
));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
WrongBatchSize
)
{
std
::
unique_ptr
<
Component
>
component
;
TF_ASSERT_OK
(
Component
::
CreateOrError
(
"SyntaxNetMstSolverComponent"
,
&
component
));
InputBatchCache
input
({
MakeSentence
(
1
).
SerializeAsString
(),
MakeSentence
(
2
).
SerializeAsString
(),
MakeSentence
(
3
).
SerializeAsString
(),
MakeSentence
(
4
).
SerializeAsString
()});
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input
));
EXPECT_THAT
(
component
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
),
test
::
IsErrorWithSubstr
(
"Non-singleton batch: got 4 elements"
));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F
(
SyntaxNetMstSolverComponentTest
,
WrongNumTokens
)
{
const
std
::
vector
<
std
::
vector
<
float
>>
adjacency
=
{{
1.0
,
0.0
,
0.0
,
0.0
},
//
{
0.0
,
1.0
,
0.0
,
0.0
},
//
{
0.0
,
0.0
,
1.0
,
0.0
},
//
{
0.0
,
0.0
,
0.0
,
1.0
}};
// 4-token adjacency matrix with 3-token sentence.
Sentence
sentence
=
MakeSentence
(
3
);
EXPECT_THAT
(
Run
(
MakeGoodSpec
(),
adjacency
,
&
sentence
),
test
::
IsErrorWithSubstr
(
"Sentence size mismatch: expected 4 tokens but got 3"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment