Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4bb31d0
Commit
a4bb31d0
authored
May 02, 2018
by
Terry Koo
Browse files
Export @195097388.
parent
dea7ecf6
Changes
294
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3711 additions
and
0 deletions
+3711
-0
research/syntaxnet/dragnn/runtime/sequence_component_transformer.cc
...yntaxnet/dragnn/runtime/sequence_component_transformer.cc
+144
-0
research/syntaxnet/dragnn/runtime/sequence_component_transformer_test.cc
...net/dragnn/runtime/sequence_component_transformer_test.cc
+261
-0
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
+75
-0
research/syntaxnet/dragnn/runtime/sequence_extractor.h
research/syntaxnet/dragnn/runtime/sequence_extractor.h
+100
-0
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
+166
-0
research/syntaxnet/dragnn/runtime/sequence_features.cc
research/syntaxnet/dragnn/runtime/sequence_features.cc
+104
-0
research/syntaxnet/dragnn/runtime/sequence_features.h
research/syntaxnet/dragnn/runtime/sequence_features.h
+159
-0
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
+346
-0
research/syntaxnet/dragnn/runtime/sequence_linker.cc
research/syntaxnet/dragnn/runtime/sequence_linker.cc
+74
-0
research/syntaxnet/dragnn/runtime/sequence_linker.h
research/syntaxnet/dragnn/runtime/sequence_linker.h
+105
-0
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
+167
-0
research/syntaxnet/dragnn/runtime/sequence_links.cc
research/syntaxnet/dragnn/runtime/sequence_links.cc
+146
-0
research/syntaxnet/dragnn/runtime/sequence_links.h
research/syntaxnet/dragnn/runtime/sequence_links.h
+169
-0
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
+484
-0
research/syntaxnet/dragnn/runtime/sequence_model.cc
research/syntaxnet/dragnn/runtime/sequence_model.cc
+193
-0
research/syntaxnet/dragnn/runtime/sequence_model.h
research/syntaxnet/dragnn/runtime/sequence_model.h
+143
-0
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
+550
-0
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
+73
-0
research/syntaxnet/dragnn/runtime/sequence_predictor.h
research/syntaxnet/dragnn/runtime/sequence_predictor.h
+94
-0
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
+158
-0
No files found.
Too many changes to show.
To preserve performance only
294 of 294+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/sequence_component_transformer.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns true if the |component_spec| has recurrent links.
bool
IsRecurrent
(
const
ComponentSpec
&
component_spec
)
{
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
return
true
;
}
return
false
;
}
// Returns the sequence-based version of the |component_type| with specification
// |component_spec|, or an empty string if there is no sequence-based version.
string
GetSequenceComponentType
(
const
string
&
component_type
,
const
ComponentSpec
&
component_spec
)
{
// TODO(googleuser): Implement a SequenceDynamicComponent that can handle
// recurrent links. This may require changes to the NetworkUnit API.
static
const
char
*
kSupportedComponentTypes
[]
=
{
"BulkDynamicComponent"
,
//
"BulkLstmComponent"
,
//
"MyelinDynamicComponent"
,
//
};
for
(
const
char
*
supported_type
:
kSupportedComponentTypes
)
{
if
(
component_type
==
supported_type
)
{
return
tensorflow
::
strings
::
StrCat
(
"Sequence"
,
supported_type
);
}
}
// Also support non-recurrent DynamicComponents. The BulkDynamicComponent
// requires determinism, but the SequenceBulkDynamicComponent does not, so
// it's not sufficient to only upgrade from BulkDynamicComponent.
if
(
component_type
==
"DynamicComponent"
&&
!
IsRecurrent
(
component_spec
))
{
return
"SequenceBulkDynamicComponent"
;
}
return
string
();
}
// Returns the |status| but coerces NOT_FOUND to OK. Sets |found| to false iff
// the |status| was NOT_FOUND.
tensorflow
::
Status
AllowNotFound
(
const
tensorflow
::
Status
&
status
,
bool
*
found
)
{
*
found
=
status
.
code
()
!=
tensorflow
::
error
::
NOT_FOUND
;
return
*
found
?
status
:
tensorflow
::
Status
::
OK
();
}
// Transformer that checks whether a sequence-based component implementation
// could be used and, if compatible, modifies the ComponentSpec accordingly.
class
SequenceComponentTransformer
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
override
;
};
tensorflow
::
Status
SequenceComponentTransformer
::
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
{
const
int
num_features
=
component_spec
->
fixed_feature_size
()
+
component_spec
->
linked_feature_size
();
if
(
num_features
==
0
)
return
tensorflow
::
Status
::
OK
();
// Look for supporting SequenceExtractors.
bool
found
=
false
;
string
extractor_types
;
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
->
fixed_feature
())
{
string
type
;
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequenceExtractor
::
Select
(
channel
,
*
component_spec
,
&
type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
tensorflow
::
strings
::
StrAppend
(
&
extractor_types
,
type
,
","
);
}
if
(
!
extractor_types
.
empty
())
extractor_types
.
pop_back
();
// remove comma
// Look for supporting SequenceLinkers.
string
linker_types
;
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
->
linked_feature
())
{
string
type
;
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequenceLinker
::
Select
(
channel
,
*
component_spec
,
&
type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
tensorflow
::
strings
::
StrAppend
(
&
linker_types
,
type
,
","
);
}
if
(
!
linker_types
.
empty
())
linker_types
.
pop_back
();
// remove comma
// Look for a supporting SequencePredictor, if predictions are necessary.
string
predictor_type
;
if
(
!
TransitionSystemTraits
(
*
component_spec
).
is_deterministic
)
{
TF_RETURN_IF_ERROR
(
AllowNotFound
(
SequencePredictor
::
Select
(
*
component_spec
,
&
predictor_type
),
&
found
));
if
(
!
found
)
return
tensorflow
::
Status
::
OK
();
}
// Look for a supporting sequence-based component type.
const
string
sequence_component_type
=
GetSequenceComponentType
(
component_type
,
*
component_spec
);
if
(
sequence_component_type
.
empty
())
return
tensorflow
::
Status
::
OK
();
// Success; make modifications.
component_spec
->
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
RegisteredModuleSpec
*
builder
=
component_spec
->
mutable_component_builder
();
builder
->
set_registered_name
(
sequence_component_type
);
(
*
builder
->
mutable_parameters
())[
"sequence_extractors"
]
=
extractor_types
;
(
*
builder
->
mutable_parameters
())[
"sequence_linkers"
]
=
linker_types
;
(
*
builder
->
mutable_parameters
())[
"sequence_predictor"
]
=
predictor_type
;
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
SequenceComponentTransformer
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_component_transformer_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Arbitrary supported component type.
constexpr
char
kSupportedComponentType
[]
=
"MyelinDynamicComponent"
;
// Sequence-based version of the component type.
constexpr
char
kTransformedComponentType
[]
=
"SequenceMyelinDynamicComponent"
;
// Trivial extractor that supports components named "supported".
class
SupportIfNamedSupportedExtractor
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
SupportIfNamedSupportedExtractor
);
// Trivial extractor that supports components if they have a resource. This is
// used to generate a "multiple supported extractors" conflict.
class
SupportIfHasResourcesExtractor
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
resource_size
()
>
0
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
SupportIfHasResourcesExtractor
);
// Trivial linker that supports components named "supported".
class
SupportIfNamedSupportedLinker
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
SupportIfNamedSupportedLinker
);
// Trivial predictor that supports components named "supported".
class
SupportIfNamedSupportedPredictor
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"supported"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
SupportIfNamedSupportedPredictor
);
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"supported"
);
component_spec
.
set_num_actions
(
10
);
component_spec
.
add_fixed_feature
();
component_spec
.
add_fixed_feature
();
component_spec
.
add_linked_feature
();
component_spec
.
add_linked_feature
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
kSupportedComponentType
);
return
component_spec
;
}
// Tests that a compatible spec is modified to use a new backend and component
// builder with SequenceExtractors, SequenceLinkers, and SequencePredictor.
TEST
(
SequenceComponentTransformerTest
,
Compatible
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
kTransformedComponentType
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"SupportIfNamedSupportedPredictor"
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
// Tests that a compatible deterministic spec is modified to use a new backend
// and component builder with SequenceExtractors and SequenceLinkers only.
TEST
(
SequenceComponentTransformerTest
,
CompatibleNoPredictor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
1
);
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
kTransformedComponentType
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
""
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
// Tests that a ComponentSpec with no features is incompatible.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleNoFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec with the wrong component builder is incompatible.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleComponentBuilder
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"bad"
);
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec is incompatible if it is not supported by any
// SequenceExtractor.
TEST
(
SequenceComponentTransformerTest
,
IncompatibleNoSupportingSequenceExtractor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_name
(
"bad"
);
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a ComponentSpec fails if multiple SequenceExtractors support it.
TEST
(
SequenceComponentTransformerTest
,
FailIfMultipleSupportingSequenceExtractors
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
add_resource
();
// triggers SupportIfHasResourcesExtractor
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Multiple SequenceExtractors support channel"
));
}
// Tests that a DynamicComponent is not upgraded if it is recurrent.
TEST
(
SequenceComponentTransformerTest
,
RecurrentDynamicComponent
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"DynamicComponent"
);
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
component_spec
.
name
());
const
ComponentSpec
unchanged_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
unchanged_spec
));
}
// Tests that a DynamicComponent is upgraded to SequenceBulkDynamicComponent if
// it is non-recurrent.
TEST
(
SequenceComponentTransformerTest
,
NonRecurrentDynamicComponent
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"DynamicComponent"
);
ComponentSpec
modified_spec
=
component_spec
;
modified_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
modified_spec
.
mutable_component_builder
()
->
set_registered_name
(
"SequenceBulkDynamicComponent"
);
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"
});
modified_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"SupportIfNamedSupportedPredictor"
});
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_extractor.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceExtractor
::
Select
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequenceExtractor
>
current_extractor
(
factory_function
());
if
(
!
current_extractor
->
Supports
(
channel
,
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequenceExtractors support channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequenceExtractor supports channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceExtractor
::
New
(
const
string
&
name
,
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceExtractor
>
*
extractor
)
{
std
::
unique_ptr
<
SequenceExtractor
>
matching_extractor
;
TF_RETURN_IF_ERROR
(
SequenceExtractor
::
CreateOrError
(
name
,
&
matching_extractor
));
TF_RETURN_IF_ERROR
(
matching_extractor
->
Initialize
(
channel
,
component_spec
));
// Success; make modifications.
*
extractor
=
std
::
move
(
matching_extractor
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Extractor"
,
dragnn
::
runtime
::
SequenceExtractor
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_extractor.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for feature extraction for sequence inputs.
//
// This extractor can be used to avoid ComputeSession overhead in simple cases;
// for example, extracting a sequence of character or word IDs for an LSTM.
class
SequenceExtractor
:
public
RegisterableClass
<
SequenceExtractor
>
{
public:
// Sets |extractor| to an instance of the subclass named |name| initialized
// from the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceExtractor
>
*
extractor
);
SequenceExtractor
(
const
SequenceExtractor
&
)
=
delete
;
SequenceExtractor
&
operator
=
(
const
SequenceExtractor
&
)
=
delete
;
virtual
~
SequenceExtractor
()
=
default
;
// Sets |name| to the registered name of the SequenceExtractor that supports
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing. The returned statuses include:
// * OK: If a supporting SequenceExtractor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Overwrites |ids| with the sequence of features extracted from the |input|.
// On error, returns non-OK.
virtual
tensorflow
::
Status
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
=
0
;
protected:
SequenceExtractor
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequenceExtractor
>::
Create
;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual
bool
Supports
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Extractor"
,
dragnn
::
runtime
::
SequenceExtractor
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceExtractor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
research/syntaxnet/dragnn/runtime/sequence_extractor_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequenceExtractorTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequenceExtractor
::
New
(
name
,
{},
component_spec
,
&
extractor
));
EXPECT_NE
(
extractor
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequenceExtractorTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequenceExtractor
::
New
(
name
,
{},
component_spec
,
&
extractor
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
extractor
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequenceExtractorTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequenceExtractor supports channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequenceExtractorTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequenceExtractor
::
New
(
"Unsupported"
,
{},
component_spec
,
&
extractor
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Extractor"
));
EXPECT_EQ
(
extractor
,
nullptr
);
}
// Tests that multiple supporting extractors are reported as INTERNAL errors.
TEST
(
SequenceExtractorTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequenceExtractor
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequenceExtractors support channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_features.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceFeatureManager
::
Reset
(
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
)
{
const
size_t
num_channels
=
fixed_embedding_manager
->
channel_configs_
.
size
();
if
(
component_spec
.
fixed_feature_size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between FixedEmbeddingManager ("
,
num_channels
,
") and ComponentSpec ("
,
component_spec
.
fixed_feature_size
(),
")"
);
}
if
(
sequence_extractor_types
.
size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between FixedEmbeddingManager ("
,
num_channels
,
") and SequenceExtractors ("
,
sequence_extractor_types
.
size
(),
")"
);
}
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
size
()
>
1
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Multi-embedding fixed features are not supported for channel: "
,
channel
.
ShortDebugString
());
}
}
std
::
vector
<
ChannelConfig
>
local_configs
;
// avoid modification on error
for
(
size_t
channel_id
=
0
;
channel_id
<
num_channels
;
++
channel_id
)
{
local_configs
.
emplace_back
();
ChannelConfig
&
channel_config
=
local_configs
.
back
();
const
FixedEmbeddingManager
::
ChannelConfig
&
wrapped_config
=
fixed_embedding_manager
->
channel_configs_
[
channel_id
];
channel_config
.
is_embedded
=
wrapped_config
.
is_embedded
;
channel_config
.
embedding_matrix
=
wrapped_config
.
embedding_matrix
;
TF_RETURN_IF_ERROR
(
SequenceExtractor
::
New
(
sequence_extractor_types
[
channel_id
],
component_spec
.
fixed_feature
(
channel_id
),
component_spec
,
&
channel_config
.
extractor
));
}
// Success; make modifications.
zeros_
=
fixed_embedding_manager
->
zeros_
.
view
();
channel_configs_
=
std
::
move
(
local_configs
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceFeatures
::
Reset
(
const
SequenceFeatureManager
*
manager
,
InputBatchCache
*
input
)
{
manager_
=
manager
;
zeros_
=
manager
->
zeros_
;
num_channels_
=
manager
->
channel_configs_
.
size
();
num_steps_
=
0
;
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.ids sub-vector is never deallocated.
if
(
num_channels_
>
channels_
.
size
())
channels_
.
resize
(
num_channels_
);
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
Channel
&
channel
=
channels_
[
channel_id
];
const
SequenceFeatureManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
channel
.
embedding_matrix
=
channel_config
.
embedding_matrix
;
TF_RETURN_IF_ERROR
(
channel_config
.
extractor
->
GetIds
(
input
,
&
channel
.
ids
));
if
(
channel_id
==
0
)
{
num_steps_
=
channel
.
ids
.
size
();
}
else
if
(
channel
.
ids
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent feature sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
ids
.
size
(),
" but expected "
,
num_steps_
);
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_features.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting fixed embeddings for sequence-based
// models. Analogous to FixedEmbeddingManager and FixedEmbeddings, but uses
// SequenceExtractor instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#define DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Manager for fixed embeddings for sequence-based models. This is a wrapper
// around the FixedEmbeddingManager.
class
SequenceFeatureManager
{
public:
// Creates an empty manager.
SequenceFeatureManager
()
=
default
;
// Resets this to wrap the |fixed_embedding_manager|, which must outlive this.
// The |sequence_extractor_types| should name one SequenceExtractor subclass
// per channel; e.g., "SyntaxNetCharacterSequenceExtractor". This initializes
// each SequenceExtractor from the |component_spec|. On error, returns non-OK
// and does not modify this.
tensorflow
::
Status
Reset
(
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
);
// Accessors.
size_t
num_channels
()
const
{
return
channel_configs_
.
size
();
}
private:
friend
class
SequenceFeatures
;
// Configuration for a single fixed embedding channel.
struct
ChannelConfig
{
// Whether this channel is embedded.
bool
is_embedded
=
true
;
// Embedding matrix of this channel. Only used if |is_embedded| is true.
Matrix
<
float
>
embedding_matrix
;
// Extractor for sequences of feature IDs.
std
::
unique_ptr
<
SequenceExtractor
>
extractor
;
};
// Array of zeros that can be substituted for missing feature IDs. This is a
// reference to the corresponding array in the FixedEmbeddingManager.
AlignedView
zeros_
;
// Ordered list of configurations for each channel.
std
::
vector
<
ChannelConfig
>
channel_configs_
;
};
// A set of fixed embeddings for a sequence-based model. Configured by a
// SequenceFeatureManager.
class
SequenceFeatures
{
public:
// Creates an empty set of embeddings.
SequenceFeatures
()
=
default
;
// Resets this to the sequences of fixed features managed by the |manager| on
// the |input|. The |manager| must live until this is destroyed or Reset(),
// and should not be modified during that time. On error, returns non-OK.
tensorflow
::
Status
Reset
(
const
SequenceFeatureManager
*
manager
,
InputBatchCache
*
input
);
// Returns the feature ID or embedding for the |target_index|'th element of
// the |channel_id|'th channel. Each method is only valid for a non-embedded
// or embedded channel, respectively.
int32
GetId
(
size_t
channel_id
,
size_t
target_index
)
const
;
Vector
<
float
>
GetEmbedding
(
size_t
channel_id
,
size_t
target_index
)
const
;
// Accessors.
size_t
num_channels
()
const
{
return
num_channels_
;
}
size_t
num_steps
()
const
{
return
num_steps_
;
}
private:
// Data associated with a single fixed embedding channel.
struct
Channel
{
// Embedding matrix of this channel. Only used for embedded channels.
Matrix
<
float
>
embedding_matrix
;
// Feature IDs for each step.
std
::
vector
<
int32
>
ids
;
};
// Manager from the most recent Reset().
const
SequenceFeatureManager
*
manager_
=
nullptr
;
// Zero vector from the most recent Reset().
AlignedView
zeros_
;
// Number of channels and steps from the most recent Reset().
size_t
num_channels_
=
0
;
size_t
num_steps_
=
0
;
// Ordered list of fixed embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std
::
vector
<
Channel
>
channels_
;
};
// Implementation details below.
inline
int32
SequenceFeatures
::
GetId
(
size_t
channel_id
,
size_t
target_index
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
DCHECK
(
!
manager_
->
channel_configs_
[
channel_id
].
is_embedded
);
const
Channel
&
channel
=
channels_
[
channel_id
];
return
channel
.
ids
[
target_index
];
}
inline
Vector
<
float
>
SequenceFeatures
::
GetEmbedding
(
size_t
channel_id
,
size_t
target_index
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
DCHECK
(
manager_
->
channel_configs_
[
channel_id
].
is_embedded
);
const
Channel
&
channel
=
channels_
[
channel_id
];
const
int32
id
=
channel
.
ids
[
target_index
];
return
id
<
0
?
Vector
<
float
>
(
zeros_
,
channel
.
embedding_matrix
.
num_columns
())
:
channel
.
embedding_matrix
.
row
(
id
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
research/syntaxnet/dragnn/runtime/sequence_features_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Number of transition steps to take in each component in the network.
const
size_t
kNumSteps
=
10
;
// A working one-channel ComponentSpec. This is intentionally identical to the
// first channel of |kMultiSpec|, so they can use the same embedding matrix.
const
char
kSingleSpec
[]
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
})"
;
const
size_t
kSingleRows
=
13
;
const
size_t
kSingleColumns
=
11
;
constexpr
float
kSingleValue
=
1.25
;
// A working multi-channel ComponentSpec.
const
char
kMultiSpec
[]
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
})"
;
// Fails to initialize.
class
FailToInitialize
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"No initialization for you!"
);
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
FailToInitialize
);
// Initializes OK, then fails to extract features.
class
FailToGetIds
:
public
FailToInitialize
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
errors
::
Internal
(
"No features for you!"
);
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
FailToGetIds
);
// Initializes OK and extracts the previous step.
class
ExtractPrevious
:
public
FailToGetIds
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
resize
(
kNumSteps
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
(
*
ids
)[
i
]
=
i
-
1
;
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
ExtractPrevious
);
// Initializes OK but produces the wrong number of features.
class
WrongNumberOfIds
:
public
FailToGetIds
{
public:
// Implements SequenceExtractor.
tensorflow
::
Status
GetIds
(
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
resize
(
kNumSteps
+
1
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
WrongNumberOfIds
);
class
SequenceFeatureManagerTest
:
public
NetworkTestBase
{
protected:
// Creates a SequenceFeatureManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow
::
Status
ResetManager
(
const
string
&
component_spec_text
,
const
std
::
vector
<
string
>
&
sequence_extractor_types
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kSingleRows
,
kSingleColumns
,
kSingleValue
);
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
return
manager_
.
Reset
(
&
fixed_embedding_manager_
,
component_spec
,
sequence_extractor_types
);
}
FixedEmbeddingManager
fixed_embedding_manager_
;
SequenceFeatureManager
manager_
;
};
// Tests that SequenceFeatureManager is empty by default.
TEST_F
(
SequenceFeatureManagerTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceFeatureManager is empty when reset to an empty spec.
TEST_F
(
SequenceFeatureManagerTest
,
EmptySpec
)
{
TF_EXPECT_OK
(
ResetManager
(
""
,
{}));
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceFeatureManager works with a single channel.
TEST_F
(
SequenceFeatureManagerTest
,
OneChannel
)
{
TF_EXPECT_OK
(
ResetManager
(
kSingleSpec
,
{
"ExtractPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
1
);
}
// Tests that SequenceFeatureManager works with multiple channels.
TEST_F
(
SequenceFeatureManagerTest
,
MultipleChannels
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
3
);
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F
(
SequenceFeatureManagerTest
,
MismatchedFixedManagerAndComponentSpec
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMultiSpec
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddFixedEmbeddingMatrix
(
0
,
kSingleRows
,
kSingleColumns
,
kSingleValue
);
AddComponent
(
kTestComponentName
);
TF_ASSERT_OK
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
// Remove one fixed feature, resulting in a mismatch.
component_spec
.
mutable_fixed_feature
()
->
RemoveLast
();
EXPECT_THAT
(
manager_
.
Reset
(
&
fixed_embedding_manager_
,
component_spec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between FixedEmbeddingManager "
"(3) and ComponentSpec (2)"
));
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// SequenceExtractors are mismatched.
TEST_F
(
SequenceFeatureManagerTest
,
MismatchedFixedManagerAndSequenceExtractors
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between FixedEmbeddingManager "
"(3) and SequenceExtractors (2)"
));
}
// Tests that SequenceFeatureManager fails if a channel has multiple embeddings.
TEST_F
(
SequenceFeatureManagerTest
,
UnsupportedMultiEmbeddingChannel
)
{
const
string
kBadSpec
=
R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 2 # bad
})"
;
EXPECT_THAT
(
ResetManager
(
kBadSpec
,
{
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"Multi-embedding fixed features are not supported"
));
}
// Tests that SequenceFeatureManager fails if one of the SequenceExtractors
// fails to initialize.
TEST_F
(
SequenceFeatureManagerTest
,
FailToInitializeSequenceExtractor
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"FailToInitialize"
,
"ExtractPrevious"
}),
test
::
IsErrorWithSubstr
(
"No initialization for you!"
));
}
// Tests that SequenceFeatureManager is OK even if the SequenceExtractors would
// fail in GetIds().
TEST_F
(
SequenceFeatureManagerTest
,
ManagerDoesntCareAboutGetIds
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"FailToGetIds"
,
"FailToGetIds"
,
"FailToGetIds"
}));
}
class
SequenceFeaturesTest
:
public
SequenceFeatureManagerTest
{
protected:
// Resets the |sequence_features_| on the |manager_| and |input_batch_cache_|
// and returns the resulting status.
tensorflow
::
Status
ResetFeatures
()
{
return
sequence_features_
.
Reset
(
&
manager_
,
&
input_batch_cache_
);
}
InputBatchCache
input_batch_cache_
;
SequenceFeatures
sequence_features_
;
};
// Tests that SequenceFeatures is empty by default.
TEST_F
(
SequenceFeaturesTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
sequence_features_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_features_
.
num_steps
(),
0
);
}
// Tests that SequenceFeatures is empty when reset by an empty manager.
TEST_F
(
SequenceFeaturesTest
,
EmptyManager
)
{
TF_ASSERT_OK
(
ResetManager
(
""
,
{}));
TF_EXPECT_OK
(
ResetFeatures
());
EXPECT_EQ
(
sequence_features_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_features_
.
num_steps
(),
0
);
}
// Tests that SequenceFeatures fails when one of the SequenceExtractors fails.
TEST_F
(
SequenceFeaturesTest
,
FailToGetIds
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"FailToGetIds"
}));
EXPECT_THAT
(
ResetFeatures
(),
test
::
IsErrorWithSubstr
(
"No features for you!"
));
}
// Tests that SequenceFeatures fails when the SequenceExtractors produce
// different numbers of features.
TEST_F
(
SequenceFeaturesTest
,
MismatchedNumbersOfFeatures
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"WrongNumberOfIds"
}));
EXPECT_THAT
(
ResetFeatures
(),
test
::
IsErrorWithSubstr
(
"Inconsistent feature sequence lengths at "
"channel ID 2: got 11 but expected 10"
));
}
// Tests that SequenceFeatures works as expected on one channel.
TEST_F
(
SequenceFeaturesTest
,
SingleChannel
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"ExtractPrevious"
}));
TF_ASSERT_OK
(
ResetFeatures
());
ASSERT_EQ
(
sequence_features_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_features_
.
num_steps
(),
kNumSteps
);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
0
),
kSingleColumns
,
0.0
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
0
),
"is_embedded"
);
// The remaining feature IDs map to valid embedding rows.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
i
),
kSingleColumns
,
kSingleValue
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
i
),
"is_embedded"
);
}
}
// Tests that SequenceFeatures works as expected on multiple channels.
TEST_F
(
SequenceFeaturesTest
,
ManyChannels
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"ExtractPrevious"
,
"ExtractPrevious"
,
"ExtractPrevious"
}));
TF_ASSERT_OK
(
ResetFeatures
());
ASSERT_EQ
(
sequence_features_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_features_
.
num_steps
(),
kNumSteps
);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
0
),
kSingleColumns
,
0.0
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
1
,
0
),
-
1
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
2
,
0
),
-
1
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
0
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
1
,
0
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
2
,
0
),
"is_embedded"
);
// The remaining features point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
sequence_features_
.
GetEmbedding
(
0
,
i
),
kSingleColumns
,
kSingleValue
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
1
,
i
),
i
-
1
);
EXPECT_EQ
(
sequence_features_
.
GetId
(
2
,
i
),
i
-
1
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetId
(
0
,
i
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
1
,
i
),
"is_embedded"
);
EXPECT_DEBUG_DEATH
(
sequence_features_
.
GetEmbedding
(
2
,
i
),
"is_embedded"
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_linker.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceLinker
::
Select
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequenceLinker
>
current_linker
(
factory_function
());
if
(
!
current_linker
->
Supports
(
channel
,
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequenceLinkers support channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequenceLinker supports channel "
,
channel
.
ShortDebugString
(),
" of ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceLinker
::
New
(
const
string
&
name
,
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceLinker
>
*
linker
)
{
std
::
unique_ptr
<
SequenceLinker
>
matching_linker
;
TF_RETURN_IF_ERROR
(
SequenceLinker
::
CreateOrError
(
name
,
&
matching_linker
));
TF_RETURN_IF_ERROR
(
matching_linker
->
Initialize
(
channel
,
component_spec
));
// Success; make modifications.
*
linker
=
std
::
move
(
matching_linker
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Linker"
,
dragnn
::
runtime
::
SequenceLinker
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_linker.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for link extraction for sequence inputs.
//
// This can be used to avoid ComputeSession overhead in simple cases; for
// example, extracting a sequence of identity or reverse-identity links.
class
SequenceLinker
:
public
RegisterableClass
<
SequenceLinker
>
{
public:
// Sets |linker| to an instance of the subclass named |name| initialized from
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequenceLinker
>
*
linker
);
SequenceLinker
(
const
SequenceLinker
&
)
=
delete
;
SequenceLinker
&
operator
=
(
const
SequenceLinker
&
)
=
delete
;
virtual
~
SequenceLinker
()
=
default
;
// Sets |name| to the registered name of the SequenceLinker that supports the
// |channel| of the |component_spec|. On error, returns non-OK and modifies
// nothing. The returned statuses include:
// * OK: If a supporting SequenceLinker was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Overwrites |links| with the sequence of translated link step indices for
// the |input|. Specifically, sets links[i] to the (possibly out-of-bounds)
// step index to fetch from the source component for the i'th element of the
// target sequence. Assumes that |source_num_steps| is the number of steps
// taken by the source component. On error, returns non-OK.
virtual
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
input
,
std
::
vector
<
int32
>
*
links
)
const
=
0
;
protected:
SequenceLinker
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequenceLinker
>::
Create
;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual
bool
Supports
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
channel
,
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Linker"
,
dragnn
::
runtime
::
SequenceLinker
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceLinker, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
research/syntaxnet/dragnn/runtime/sequence_linker_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequenceLinkerTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequenceLinker
::
New
(
name
,
{},
component_spec
,
&
linker
));
EXPECT_NE
(
linker
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequenceLinkerTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequenceLinker
::
New
(
name
,
{},
component_spec
,
&
linker
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
linker
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequenceLinkerTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequenceLinker supports channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequenceLinkerTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequenceLinker
>
linker
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequenceLinker
::
New
(
"Unsupported"
,
{},
component_spec
,
&
linker
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Linker"
));
EXPECT_EQ
(
linker
,
nullptr
);
}
// Tests that multiple supporting linkers are reported as INTERNAL errors.
TEST
(
SequenceLinkerTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequenceLinker
::
Select
({},
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequenceLinkers support channel"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_links.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequenceLinkManager
::
Reset
(
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_linker_types
)
{
const
size_t
num_channels
=
linked_embedding_manager
->
channel_configs_
.
size
();
if
(
component_spec
.
linked_feature_size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between LinkedEmbeddingManager ("
,
num_channels
,
") and ComponentSpec ("
,
component_spec
.
linked_feature_size
(),
")"
);
}
if
(
sequence_linker_types
.
size
()
!=
num_channels
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Channel mismatch between LinkedEmbeddingManager ("
,
num_channels
,
") and SequenceLinkers ("
,
sequence_linker_types
.
size
(),
")"
);
}
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
embedding_dim
()
>=
0
)
{
return
tensorflow
::
errors
::
Unimplemented
(
"Transformed linked features are not supported for channel: "
,
channel
.
ShortDebugString
());
}
}
std
::
vector
<
ChannelConfig
>
local_configs
;
// avoid modification on error
for
(
size_t
channel_id
=
0
;
channel_id
<
num_channels
;
++
channel_id
)
{
const
LinkedFeatureChannel
&
channel
=
component_spec
.
linked_feature
(
channel_id
);
local_configs
.
emplace_back
();
ChannelConfig
&
channel_config
=
local_configs
.
back
();
channel_config
.
is_recurrent
=
channel
.
source_component
()
==
component_spec
.
name
();
channel_config
.
handle
=
linked_embedding_manager
->
channel_configs_
[
channel_id
].
source_handle
;
TF_RETURN_IF_ERROR
(
SequenceLinker
::
New
(
sequence_linker_types
[
channel_id
],
component_spec
.
linked_feature
(
channel_id
),
component_spec
,
&
channel_config
.
linker
));
}
// Success; make modifications.
zeros_
=
linked_embedding_manager
->
zeros_
.
view
();
channel_configs_
=
std
::
move
(
local_configs
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceLinks
::
Reset
(
bool
add_steps
,
const
SequenceLinkManager
*
manager
,
NetworkStates
*
network_states
,
InputBatchCache
*
input
)
{
zeros_
=
manager
->
zeros_
;
num_channels_
=
manager
->
channel_configs_
.
size
();
num_steps_
=
0
;
bool
have_num_steps
=
false
;
// true if |num_steps_| was assigned
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.links sub-vector is never deallocated.
if
(
num_channels_
>
channels_
.
size
())
channels_
.
resize
(
num_channels_
);
// Process non-recurrent links first.
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
const
SequenceLinkManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
if
(
channel_config
.
is_recurrent
)
continue
;
Channel
&
channel
=
channels_
[
channel_id
];
channel
.
layer
=
network_states
->
GetLayer
(
channel_config
.
handle
);
TF_RETURN_IF_ERROR
(
channel_config
.
linker
->
GetLinks
(
channel
.
layer
.
num_rows
(),
input
,
&
channel
.
links
));
if
(
!
have_num_steps
)
{
num_steps_
=
channel
.
links
.
size
();
have_num_steps
=
true
;
}
else
if
(
channel
.
links
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent link sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
links
.
size
(),
" but expected "
,
num_steps_
);
}
}
// Add steps to the |network_states|, if requested.
if
(
add_steps
)
{
if
(
!
have_num_steps
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Cannot infer the number of steps to add because there are no "
"non-recurrent links"
);
}
network_states
->
AddSteps
(
num_steps_
);
}
// Process recurrent links. These require that the current component in the
// |network_states| has been sized to the proper number of steps.
for
(
int
channel_id
=
0
;
channel_id
<
num_channels_
;
++
channel_id
)
{
const
SequenceLinkManager
::
ChannelConfig
&
channel_config
=
manager
->
channel_configs_
[
channel_id
];
if
(
!
channel_config
.
is_recurrent
)
continue
;
Channel
&
channel
=
channels_
[
channel_id
];
channel
.
layer
=
network_states
->
GetLayer
(
channel_config
.
handle
);
TF_RETURN_IF_ERROR
(
channel_config
.
linker
->
GetLinks
(
channel
.
layer
.
num_rows
(),
input
,
&
channel
.
links
));
if
(
!
have_num_steps
)
{
num_steps_
=
channel
.
links
.
size
();
have_num_steps
=
true
;
}
else
if
(
channel
.
links
.
size
()
!=
num_steps_
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Inconsistent link sequence lengths at channel ID "
,
channel_id
,
": got "
,
channel
.
links
.
size
(),
" but expected "
,
num_steps_
);
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_links.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting linked embeddings for sequence-based
// models. Analogous to LinkedEmbeddingManager and LinkedEmbeddings, but uses
// SequenceLinker instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Manager for linked embeddings for sequence-based models. This is a wrapper
// around the LinkedEmbeddingManager.
class
SequenceLinkManager
{
public:
// Creates an empty manager.
SequenceLinkManager
()
=
default
;
// Resets this to wrap the |linked_embedding_manager|, which must outlive
// this. The |sequence_linker_types| should name one SequenceLinker subclass
// per channel; e.g., {"IdentitySequenceLinker", "ReversedSequenceLinker"}.
// This initializes each SequenceLinker from the |component_spec|. On error,
// returns non-OK and does not modify this.
tensorflow
::
Status
Reset
(
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
const
ComponentSpec
&
component_spec
,
const
std
::
vector
<
string
>
&
sequence_linker_types
);
// Accessors.
size_t
num_channels
()
const
{
return
channel_configs_
.
size
();
}
private:
friend
class
SequenceLinks
;
// Configuration for a single linked embedding channel.
struct
ChannelConfig
{
// Whether this link is recurrent.
bool
is_recurrent
=
false
;
// Handle to the source layer in the relevant NetworkStates.
LayerHandle
<
float
>
handle
;
// Extractor for sequences of translated link indices.
std
::
unique_ptr
<
SequenceLinker
>
linker
;
};
// Array of zeros that can be substituted for out-of-bounds embeddings. This
// is a reference to the corresponding array in the LinkedEmbeddingManager.
// See the large comment in linked_embeddings.cc for reference.
AlignedView
zeros_
;
// Ordered list of configurations for each channel.
std
::
vector
<
ChannelConfig
>
channel_configs_
;
};
// A set of linked embeddings for a sequence-based model. Configured by a
// SequenceLinkManager.
class
SequenceLinks
{
public:
// Creates an empty set of embeddings.
SequenceLinks
()
=
default
;
// Resets this to the sequences of linked embeddings managed by the |manager|
// on the |input|. Retrieves layers from the |network_states|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. If |add_steps| is true, then infers the number of steps
// from the non-recurrent links and adds steps to the |network_states| before
// processing the recurrent links. On error, returns non-OK.
//
// NB: Recurrent links are tricky, because the |network_states| must be filled
// with steps before processing recurrent links. There are two approaches:
// 1. Add steps to the |network_states| before calling Reset(). This only
// works if the component also has fixed features, which can be used to
// infer the number of steps.
// 2. Set |add_steps| to true, so steps are added during Reset(). This only
// works if the component also has non-recurrent links, which can be used
// to infer the number of steps.
// If a component only has recurrent links then neither of the above works,
// but such a component would be nonsensical: it recurses on itself with no
// external input.
tensorflow
::
Status
Reset
(
bool
add_steps
,
const
SequenceLinkManager
*
manager
,
NetworkStates
*
network_states
,
InputBatchCache
*
input
);
// Retrieves the linked embedding for the |target_index|'th element of the
// |channel_id|'th channel. Sets |embedding| to the linked embedding vector
// and sets |is_out_of_bounds| to true if the link is out of bounds.
void
Get
(
size_t
channel_id
,
size_t
target_index
,
Vector
<
float
>
*
embedding
,
bool
*
is_out_of_bounds
)
const
;
// Accessors.
size_t
num_channels
()
const
{
return
num_channels_
;
}
size_t
num_steps
()
const
{
return
num_steps_
;
}
private:
// Data associated with a single linked embedding channel.
struct
Channel
{
// Source layer activations.
Matrix
<
float
>
layer
;
// Translated link indices for each step.
std
::
vector
<
int32
>
links
;
};
// Zero vector from the most recent Reset().
AlignedView
zeros_
;
// Number of channels and steps from the most recent Reset().
size_t
num_channels_
=
0
;
size_t
num_steps_
=
0
;
// Ordered list of linked embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std
::
vector
<
Channel
>
channels_
;
};
// Implementation details below.
inline
void
SequenceLinks
::
Get
(
size_t
channel_id
,
size_t
target_index
,
Vector
<
float
>
*
embedding
,
bool
*
is_out_of_bounds
)
const
{
DCHECK_LT
(
channel_id
,
num_channels
());
DCHECK_LT
(
target_index
,
num_steps
());
const
Channel
&
channel
=
channels_
[
channel_id
];
const
int32
link
=
channel
.
links
[
target_index
];
*
is_out_of_bounds
=
(
link
<
0
||
link
>=
channel
.
layer
.
num_rows
());
if
(
*
is_out_of_bounds
)
{
*
embedding
=
Vector
<
float
>
(
zeros_
,
channel
.
layer
.
num_columns
());
}
else
{
*
embedding
=
channel
.
layer
.
row
(
link
);
}
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
research/syntaxnet/dragnn/runtime/sequence_links_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Dimensions of the layers in the network (see ResetManager() below).
const
size_t
kPrevious1LayerDim
=
16
;
const
size_t
kPrevious2LayerDim
=
32
;
const
size_t
kRecurrentLayerDim
=
48
;
// Number of transition steps to take in each component in the network.
const
size_t
kNumSteps
=
10
;
// A working one-channel ComponentSpec.
const
char
kSingleSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})"
;
// A working multi-channel ComponentSpec.
const
char
kMultiSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'source_component_2'
source_layer: 'previous_2'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})"
;
// A recurrent-only ComponentSpec.
const
char
kRecurrentSpec
[]
=
R"(linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})"
;
// Fails to initialize.
class
FailToInitialize
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
component_spec
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"No initialization for you!"
);
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
LOG
(
FATAL
)
<<
"Should never be called."
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
FailToInitialize
);
// Initializes OK, then fails to extract links.
class
FailToGetLinks
:
public
FailToInitialize
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
)
const
override
{
return
tensorflow
::
errors
::
Internal
(
"No links for you!"
);
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
FailToGetLinks
);
// Initializes OK and links to the previous step.
class
LinkToPrevious
:
public
FailToGetLinks
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
GetLinks
(
size_t
source_num_steps
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
resize
(
source_num_steps
);
for
(
int
i
=
0
;
i
<
links
->
size
();
++
i
)
(
*
links
)[
i
]
=
i
-
1
;
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkToPrevious
);
// Initializes OK but produces the wrong number of links.
class
WrongNumberOfLinks
:
public
FailToGetLinks
{
public:
// Implements SequenceLinker.
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
resize
(
kNumSteps
+
1
);
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
WrongNumberOfLinks
);
class
SequenceLinkManagerTest
:
public
NetworkTestBase
{
protected:
// Sets up previous components and layers.
void
AddComponentsAndLayers
()
{
AddComponent
(
"source_component_0"
);
AddComponent
(
"source_component_1"
);
AddLayer
(
"previous_1"
,
kPrevious1LayerDim
);
AddComponent
(
"source_component_2"
);
AddLayer
(
"previous_2"
,
kPrevious2LayerDim
);
AddComponent
(
kTestComponentName
);
AddLayer
(
"recurrent"
,
kRecurrentLayerDim
);
}
// Creates a SequenceLinkManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow
::
Status
ResetManager
(
const
string
&
component_spec_text
,
const
std
::
vector
<
string
>
&
sequence_linker_types
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponentsAndLayers
();
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
return
manager_
.
Reset
(
&
linked_embedding_manager_
,
component_spec
,
sequence_linker_types
);
}
LinkedEmbeddingManager
linked_embedding_manager_
;
SequenceLinkManager
manager_
;
};
// Tests that SequenceLinkManager is empty by default.
TEST_F
(
SequenceLinkManagerTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceLinkManager is empty when reset to an empty spec.
TEST_F
(
SequenceLinkManagerTest
,
EmptySpec
)
{
TF_EXPECT_OK
(
ResetManager
(
""
,
{}));
EXPECT_EQ
(
manager_
.
num_channels
(),
0
);
}
// Tests that SequenceLinkManager works with a single channel.
TEST_F
(
SequenceLinkManagerTest
,
OneChannel
)
{
TF_EXPECT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
1
);
}
// Tests that SequenceLinkManager works with multiple channels.
TEST_F
(
SequenceLinkManagerTest
,
MultipleChannels
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
EXPECT_EQ
(
manager_
.
num_channels
(),
3
);
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F
(
SequenceLinkManagerTest
,
MismatchedLinkedManagerAndComponentSpec
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
kMultiSpec
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponentsAndLayers
();
TF_ASSERT_OK
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
// Remove one linked feature, resulting in a mismatch.
component_spec
.
mutable_linked_feature
()
->
RemoveLast
();
EXPECT_THAT
(
manager_
.
Reset
(
&
linked_embedding_manager_
,
component_spec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between LinkedEmbeddingManager "
"(3) and ComponentSpec (2)"
));
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// SequenceLinkers are mismatched.
TEST_F
(
SequenceLinkManagerTest
,
MismatchedLinkedManagerAndSequenceLinkers
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Channel mismatch between LinkedEmbeddingManager "
"(3) and SequenceLinkers (2)"
));
}
// Tests that SequenceLinkManager fails when the link is transformed.
TEST_F
(
SequenceLinkManagerTest
,
UnsupportedTransformedLink
)
{
const
string
kBadSpec
=
R"(linked_feature {
embedding_dim: 16 # bad
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})"
;
AddLinkedWeightMatrix
(
0
,
kPrevious1LayerDim
,
16
,
0.0
);
AddLinkedOutOfBoundsVector
(
0
,
16
,
0.0
);
EXPECT_THAT
(
ResetManager
(
kBadSpec
,
{
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"Transformed linked features are not supported"
));
}
// Tests that SequenceLinkManager fails if one of the SequenceLinkers fails to
// initialize.
TEST_F
(
SequenceLinkManagerTest
,
FailToInitializeSequenceLinker
)
{
EXPECT_THAT
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"FailToInitialize"
,
"LinkToPrevious"
}),
test
::
IsErrorWithSubstr
(
"No initialization for you!"
));
}
// Tests that SequenceLinkManager is OK even if the SequenceLinkers would fail
// in GetLinks().
TEST_F
(
SequenceLinkManagerTest
,
ManagerDoesntCareAboutGetLinks
)
{
TF_EXPECT_OK
(
ResetManager
(
kMultiSpec
,
{
"FailToGetLinks"
,
"FailToGetLinks"
,
"FailToGetLinks"
}));
}
// Values to fill each layer with.
const
float
kPrevious1LayerValue
=
1.0
;
const
float
kPrevious2LayerValue
=
2.0
;
const
float
kRecurrentLayerValue
=
3.0
;
class
SequenceLinksTest
:
public
SequenceLinkManagerTest
{
protected:
// Resets the |sequence_links_| using the |manager_|, |network_states_|, and
// |input_batch_cache_|, and returns the resulting status. Passes |add_steps|
// to Reset() and advances the current component by |num_steps|.
tensorflow
::
Status
ResetLinks
(
bool
add_steps
=
false
,
size_t
num_steps
=
kNumSteps
)
{
network_states_
.
Reset
(
&
network_state_manager_
);
// Fill components with steps.
StartComponent
(
kNumSteps
);
// source_component_0
StartComponent
(
kNumSteps
);
// source_component_1
StartComponent
(
kNumSteps
);
// source_component_2
StartComponent
(
num_steps
);
// current component
// Fill layers with values.
FillLayer
(
"source_component_1"
,
"previous_1"
,
kPrevious1LayerValue
);
FillLayer
(
"source_component_2"
,
"previous_2"
,
kPrevious2LayerValue
);
FillLayer
(
kTestComponentName
,
"recurrent"
,
kRecurrentLayerValue
);
return
sequence_links_
.
Reset
(
add_steps
,
&
manager_
,
&
network_states_
,
&
input_batch_cache_
);
}
InputBatchCache
input_batch_cache_
;
SequenceLinks
sequence_links_
;
};
// Tests that SequenceLinks is empty by default.
TEST_F
(
SequenceLinksTest
,
EmptyByDefault
)
{
EXPECT_EQ
(
sequence_links_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks is empty when reset by an empty manager.
TEST_F
(
SequenceLinksTest
,
EmptyManager
)
{
TF_ASSERT_OK
(
ResetManager
(
""
,
{}));
TF_EXPECT_OK
(
ResetLinks
());
EXPECT_EQ
(
sequence_links_
.
num_channels
(),
0
);
EXPECT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks fails when one of the non-recurrent SequenceLinkers
// fails.
TEST_F
(
SequenceLinksTest
,
FailToGetNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"FailToGetLinks"
,
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"No links for you!"
));
}
// Tests that SequenceLinks fails when one of the recurrent SequenceLinkers
// fails.
TEST_F
(
SequenceLinksTest
,
FailToGetRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"FailToGetLinks"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"No links for you!"
));
}
// Tests that SequenceLinks fails when the non-recurrent SequenceLinkers produce
// different numbers of links.
TEST_F
(
SequenceLinksTest
,
MismatchedNumbersOfNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"WrongNumberOfLinks"
,
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"Inconsistent link sequence lengths at "
"channel ID 1: got 11 but expected 10"
));
}
// Tests that SequenceLinks fails when the recurrent SequenceLinkers produce
// different numbers of links.
TEST_F
(
SequenceLinksTest
,
MismatchedNumbersOfRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"WrongNumberOfLinks"
}));
EXPECT_THAT
(
ResetLinks
(),
test
::
IsErrorWithSubstr
(
"Inconsistent link sequence lengths at "
"channel ID 2: got 11 but expected 10"
));
}
// Tests that SequenceLinks works as expected on one channel.
TEST_F
(
SequenceLinksTest
,
SingleChannel
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
const
Matrix
<
float
>
previous1
(
GetLayer
(
"source_component_1"
,
"previous_1"
));
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
0.0
);
// The remaining links point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
sequence_links_
.
Get
(
0
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
kPrevious1LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous1
.
row
(
i
-
1
).
data
());
}
}
// Tests that SequenceLinks works as expected on multiple channels.
TEST_F
(
SequenceLinksTest
,
ManyChannels
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
const
Matrix
<
float
>
previous1
(
GetLayer
(
"source_component_1"
,
"previous_1"
));
const
Matrix
<
float
>
previous2
(
GetLayer
(
"source_component_2"
,
"previous_2"
));
const
Matrix
<
float
>
recurrent
(
GetLayer
(
kTestComponentName
,
"recurrent"
));
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
0.0
);
sequence_links_
.
Get
(
1
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious2LayerDim
,
0.0
);
sequence_links_
.
Get
(
2
,
0
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_TRUE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kRecurrentLayerDim
,
0.0
);
// The remaining links point to the previous item.
for
(
int
i
=
1
;
i
<
kNumSteps
;
++
i
)
{
sequence_links_
.
Get
(
0
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious1LayerDim
,
kPrevious1LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous1
.
row
(
i
-
1
).
data
());
sequence_links_
.
Get
(
1
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kPrevious2LayerDim
,
kPrevious2LayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
previous2
.
row
(
i
-
1
).
data
());
sequence_links_
.
Get
(
2
,
i
,
&
embedding
,
&
is_out_of_bounds
);
EXPECT_FALSE
(
is_out_of_bounds
);
ExpectVector
(
embedding
,
kRecurrentLayerDim
,
kRecurrentLayerValue
);
EXPECT_EQ
(
embedding
.
data
(),
recurrent
.
row
(
i
-
1
).
data
());
}
}
// Tests that SequenceLinks is emptied when resetting to an empty manager after
// being reset to a non-empty manager.
TEST_F
(
SequenceLinksTest
,
ResetToEmptyAfterNonEmpty
)
{
TF_ASSERT_OK
(
ResetManager
(
kSingleSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
());
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
SequenceLinkManager
manager
;
TF_ASSERT_OK
(
sequence_links_
.
Reset
(
/*add_steps=*/
false
,
&
manager
,
&
network_states_
,
&
input_batch_cache_
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
0
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks fails when adding steps to a component with no
// non-recurrent links.
TEST_F
(
SequenceLinksTest
,
AddStepsWithNoNonRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kRecurrentSpec
,
{
"LinkToPrevious"
}));
EXPECT_THAT
(
ResetLinks
(
/*add_steps=*/
true
),
test
::
IsErrorWithSubstr
(
"Cannot infer the number of steps to add because "
"there are no non-recurrent links"
));
}
// Tests that SequenceLinks produces no links when processing a component with
// only recurrent links, and when the NetworkStates has no steps.
TEST_F
(
SequenceLinksTest
,
RecurrentLinksWithNoSteps
)
{
TF_ASSERT_OK
(
ResetManager
(
kRecurrentSpec
,
{
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
(
/*add_steps=*/
false
,
/*num_steps=*/
0
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
1
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
0
);
}
// Tests that SequenceLinks properly infers the number of steps and adds them
// when processing a component with both non-recurrent and recurrent links.
TEST_F
(
SequenceLinksTest
,
AddStepsWithNonRecurrentAndRecurrentLinks
)
{
TF_ASSERT_OK
(
ResetManager
(
kMultiSpec
,
{
"LinkToPrevious"
,
"LinkToPrevious"
,
"LinkToPrevious"
}));
TF_ASSERT_OK
(
ResetLinks
(
/*add_steps=*/
true
,
/*num_steps=*/
0
));
ASSERT_EQ
(
sequence_links_
.
num_channels
(),
3
);
ASSERT_EQ
(
sequence_links_
.
num_steps
(),
kNumSteps
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_model.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Proper backend for sequence-based models.
constexpr
char
kSupportedBackend
[]
=
"SequenceBackend"
;
// Attributes for sequence-based comopnents, attached to the component builder.
// See SequenceComponentTransformer.
struct
ComponentBuilderAttributes
:
public
Attributes
{
// Registered names of the sequence extractors to use.
Mandatory
<
std
::
vector
<
string
>>
sequence_extractors
{
"sequence_extractors"
,
this
};
// Registered names of the sequence linkers to use per channel, in order.
Mandatory
<
std
::
vector
<
string
>>
sequence_linkers
{
"sequence_linkers"
,
this
};
// Registered name of the sequence predictor to use.
Mandatory
<
string
>
sequence_predictor
{
"sequence_predictor"
,
this
};
};
}
// namespace
bool
SequenceModel
::
Supports
(
const
ComponentSpec
&
component_spec
)
{
// Require single-embedding fixed and linked features.
for
(
const
FixedFeatureChannel
&
channel
:
component_spec
.
fixed_feature
())
{
if
(
channel
.
size
()
!=
1
)
return
false
;
}
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
size
()
!=
1
)
return
false
;
}
const
bool
has_fixed_feature
=
component_spec
.
fixed_feature_size
()
>
0
;
bool
has_recurrent_link
=
false
;
bool
has_non_recurrent_link
=
false
;
for
(
const
LinkedFeatureChannel
&
channel
:
component_spec
.
linked_feature
())
{
if
(
channel
.
source_component
()
==
component_spec
.
name
())
{
has_recurrent_link
=
true
;
}
else
{
has_non_recurrent_link
=
true
;
}
}
// Recurrent links must be accompanied by fixed features or non-recurrent
// links, so the number of recurrent steps can be pre-computed.
if
(
has_recurrent_link
&&
!
has_fixed_feature
&&
!
has_non_recurrent_link
)
{
return
false
;
}
const
int
num_features
=
component_spec
.
fixed_feature_size
()
+
component_spec
.
linked_feature_size
();
return
component_spec
.
backend
().
registered_name
()
==
kSupportedBackend
&&
num_features
>
0
;
}
tensorflow
::
Status
SequenceModel
::
Initialize
(
const
ComponentSpec
&
component_spec
,
const
string
&
logits_name
,
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
NetworkStateManager
*
network_state_manager
)
{
component_name_
=
component_spec
.
name
();
if
(
component_spec
.
backend
().
registered_name
()
!=
kSupportedBackend
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Invalid component backend: "
,
component_spec
.
backend
().
registered_name
());
}
TransitionSystemTraits
traits
(
component_spec
);
deterministic_
=
traits
.
is_deterministic
;
left_to_right_
=
traits
.
is_left_to_right
;
ComponentBuilderAttributes
component_builder_attributes
;
TF_RETURN_IF_ERROR
(
component_builder_attributes
.
Reset
(
component_spec
.
component_builder
().
parameters
()));
TF_RETURN_IF_ERROR
(
sequence_feature_manager_
.
Reset
(
fixed_embedding_manager
,
component_spec
,
component_builder_attributes
.
sequence_extractors
()));
TF_RETURN_IF_ERROR
(
sequence_link_manager_
.
Reset
(
linked_embedding_manager
,
component_spec
,
component_builder_attributes
.
sequence_linkers
()));
have_fixed_features_
=
sequence_feature_manager_
.
num_channels
()
>
0
;
have_linked_features_
=
sequence_link_manager_
.
num_channels
()
>
0
;
if
(
!
have_fixed_features_
&&
!
have_linked_features_
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"No fixed or linked features"
);
}
if
(
!
deterministic_
)
{
size_t
dimension
=
0
;
TF_RETURN_IF_ERROR
(
network_state_manager
->
LookupLayer
(
component_name_
,
logits_name
,
&
dimension
,
&
logits_handle_
));
if
(
dimension
!=
component_spec
.
num_actions
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Logits dimension mismatch between NetworkStates ("
,
dimension
,
") and ComponentSpec ("
,
component_spec
.
num_actions
(),
")"
);
}
TF_RETURN_IF_ERROR
(
SequencePredictor
::
New
(
component_builder_attributes
.
sequence_predictor
(),
component_spec
,
&
sequence_predictor_
));
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceModel
::
Preprocess
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
EvaluateState
*
evaluate_state
)
const
{
InputBatchCache
*
input_batch_cache
=
compute_session
->
GetInputBatchCache
();
if
(
input_batch_cache
==
nullptr
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Null input batch"
);
}
// The feature handling below is complicated by the need to support recurrent
// links. See the comment on SequenceLinks::Reset().
NetworkStates
&
network_states
=
session_state
->
network_states
;
TF_RETURN_IF_ERROR
(
evaluate_state
->
features
.
Reset
(
&
sequence_feature_manager_
,
input_batch_cache
));
if
(
have_fixed_features_
)
{
network_states
.
AddSteps
(
evaluate_state
->
features
.
num_steps
());
}
TF_RETURN_IF_ERROR
(
evaluate_state
->
links
.
Reset
(
/*add_steps=*/
!
have_fixed_features_
,
&
sequence_link_manager_
,
&
network_states
,
input_batch_cache
));
// Initialize() ensures that there is at least one fixed or linked feature;
// use it to determine the number of steps.
size_t
num_steps
=
0
;
if
(
have_fixed_features_
&&
have_linked_features_
)
{
num_steps
=
evaluate_state
->
features
.
num_steps
();
if
(
num_steps
!=
evaluate_state
->
links
.
num_steps
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Sequence length mismatch between fixed features ("
,
num_steps
,
") and linked features ("
,
evaluate_state
->
links
.
num_steps
(),
")"
);
}
}
else
if
(
have_fixed_features_
)
{
num_steps
=
evaluate_state
->
features
.
num_steps
();
}
else
{
num_steps
=
evaluate_state
->
links
.
num_steps
();
}
// Tell the backend the current input size, so it can handle requests for
// linked features from downstream components.
static_cast
<
SequenceBackend
*>
(
compute_session
->
GetReadiedComponent
(
component_name_
))
->
SetSequenceSize
(
num_steps
);
evaluate_state
->
num_steps
=
num_steps
;
evaluate_state
->
input
=
input_batch_cache
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequenceModel
::
Predict
(
const
NetworkStates
&
network_states
,
EvaluateState
*
evaluate_state
)
const
{
if
(
!
deterministic_
)
{
const
Matrix
<
float
>
logits
(
network_states
.
GetLayer
(
logits_handle_
));
TF_RETURN_IF_ERROR
(
sequence_predictor_
->
Predict
(
logits
,
evaluate_state
->
input
));
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_model.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#define DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A class that configures and helps evaluate a sequence-based model.
//
// This class requires the SequenceBackend component backend and elides most of
// the ComputeSession feature extraction and transition system overhead.
class
SequenceModel
{
public:
// State associated with a single evaluation of the model.
struct
EvaluateState
{
// Number of transition steps in the current sequence.
size_t
num_steps
=
0
;
// Current input batch.
InputBatchCache
*
input
=
nullptr
;
// Sequence-based fixed features.
SequenceFeatures
features
;
// Sequence-based linked embeddings.
SequenceLinks
links
;
};
// Creates an uninitialized model. Call Initialize() before use.
SequenceModel
()
=
default
;
// Returns true if the |component_spec| is compatible with a sequence model.
static
bool
Supports
(
const
ComponentSpec
&
component_spec
);
// Initalizes this from the configuration in the |component_spec|. Wraps the
// |fixed_embedding_manager| and |linked_embedding_manager| in sequence-based
// versions, and requests layers from the |network_state_manager|. All of the
// managers must outlive this. If the transition system is non-deterministic,
// uses the layer named |logits_name| to make predictions later in Predict();
// otherwise, |logits_name| is ignored and Predict() does nothing. On error,
// returns non-OK.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
const
string
&
logits_name
,
const
FixedEmbeddingManager
*
fixed_embedding_manager
,
const
LinkedEmbeddingManager
*
linked_embedding_manager
,
NetworkStateManager
*
network_state_manager
);
// Resets the |evaluate_state| to values derived from the |session_state| and
// |compute_session|. Also updates the NetworkStates in the |session_state|
// and the current component of the |compute_session| with the length of the
// current sequence. Call this before producing output layers. On error,
// returns non-OK.
tensorflow
::
Status
Preprocess
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
EvaluateState
*
evaluate_state
)
const
;
// If applicable, makes predictions based on the logits in |network_states|
// and applies them to the input in the |evaluate_state|. Call this after
// producing output layers. On error, returns non-OK.
tensorflow
::
Status
Predict
(
const
NetworkStates
&
network_states
,
EvaluateState
*
evaluate_state
)
const
;
// Accessors.
bool
deterministic
()
const
{
return
deterministic_
;
}
bool
left_to_right
()
const
{
return
left_to_right_
;
}
const
SequenceLinkManager
&
sequence_link_manager
()
const
;
const
SequenceFeatureManager
&
sequence_feature_manager
()
const
;
private:
// Name of the component that this model is a part of.
string
component_name_
;
// Whether the underlying transition system is deterministic.
bool
deterministic_
=
false
;
// Whether to process sequences from left to right.
bool
left_to_right_
=
true
;
// Whether fixed or linked features are present.
bool
have_fixed_features_
=
false
;
bool
have_linked_features_
=
false
;
// Handle to the logits layer. Only used if |deterministic_| is false.
LayerHandle
<
float
>
logits_handle_
;
// Manager for sequence-based feature extractors.
SequenceFeatureManager
sequence_feature_manager_
;
// Manager for sequence-based linked embeddings.
SequenceLinkManager
sequence_link_manager_
;
// Sequence-based predictor, if |deterministic_| is false.
std
::
unique_ptr
<
SequencePredictor
>
sequence_predictor_
;
};
// Implementation details below.
inline
const
SequenceLinkManager
&
SequenceModel
::
sequence_link_manager
()
const
{
return
sequence_link_manager_
;
}
inline
const
SequenceFeatureManager
&
SequenceModel
::
sequence_feature_manager
()
const
{
return
sequence_feature_manager_
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
research/syntaxnet/dragnn/runtime/sequence_model_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
Return
;
constexpr
int
kNumSteps
=
50
;
constexpr
int
kVocabularySize
=
123
;
constexpr
int
kLinkedDim
=
11
;
constexpr
int
kLogitsDim
=
17
;
constexpr
char
kLogitsName
[]
=
"oddly_named_logits"
;
constexpr
char
kPreviousComponentName
[]
=
"previous_component"
;
constexpr
char
kPreviousLayerName
[]
=
"previous_layer"
;
constexpr
float
kPreviousLayerValue
=
-
1.0
;
// Sequence extractor that extracts [0, 2, 4, ...].
class
EvenNumbers
:
public
SequenceExtractor
{
public:
// Implements SequenceExtractor.
bool
Supports
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
FixedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetIds
(
InputBatchCache
*
,
std
::
vector
<
int32
>
*
ids
)
const
override
{
ids
->
clear
();
for
(
int
i
=
0
;
i
<
num_steps_
;
++
i
)
ids
->
push_back
(
2
*
i
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
EvenNumbers
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR
(
EvenNumbers
);
// Trivial linker that links each index to the previous one.
class
LinkToPrevious
:
public
SequenceLinker
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
LinkedFeatureChannel
&
,
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
GetLinks
(
size_t
,
InputBatchCache
*
,
std
::
vector
<
int32
>
*
links
)
const
override
{
links
->
clear
();
for
(
int
i
=
0
;
i
<
num_steps_
;
++
i
)
links
->
push_back
(
i
-
1
);
return
tensorflow
::
Status
::
OK
();
}
// Sets the number of steps to emit.
static
void
SetNumSteps
(
int
num_steps
)
{
num_steps_
=
num_steps
;
}
private:
// The number of steps to produce.
static
int
num_steps_
;
};
int
LinkToPrevious
::
num_steps_
=
kNumSteps
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER
(
LinkToPrevious
);
// Trivial predictor that captures the prediction logits.
class
CaptureLogits
:
public
SequencePredictor
{
public:
// Implements SequenceLinker.
bool
Supports
(
const
ComponentSpec
&
)
const
override
{
return
true
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
)
const
override
{
GetLogits
()
=
logits
;
return
tensorflow
::
Status
::
OK
();
}
// Returns the captured logits.
static
Matrix
<
float
>
&
GetLogits
()
{
static
auto
*
logits
=
new
Matrix
<
float
>
();
return
*
logits
;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
CaptureLogits
);
class
SequenceModelTest
:
public
NetworkTestBase
{
protected:
// Adds default call expectations. Since these are added first, they can be
// overridden by call expectations in individual tests.
SequenceModelTest
()
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
())
.
WillRepeatedly
(
Return
(
&
input_
));
EXPECT_CALL
(
compute_session_
,
GetReadiedComponent
(
kTestComponentName
))
.
WillRepeatedly
(
Return
(
&
backend_
));
// Some tests overwrite these; ensure that they are restored to the normal
// values at the start of each test.
EvenNumbers
::
SetNumSteps
(
kNumSteps
);
LinkToPrevious
::
SetNumSteps
(
kNumSteps
);
CaptureLogits
::
GetLogits
()
=
Matrix
<
float
>
();
}
// Initializes the |model_| and its underlying feature managers from the
// |component_spec|, then uses the |model_| to preprocess and predict the
// |input_|. Also sets each row of the logits to twice its row index. On
// error, returns non-OK.
tensorflow
::
Status
Run
(
ComponentSpec
component_spec
)
{
component_spec
.
set_name
(
kTestComponentName
);
AddComponent
(
kPreviousComponentName
);
AddLayer
(
kPreviousLayerName
,
kLinkedDim
);
AddComponent
(
kTestComponentName
);
AddLayer
(
kLogitsName
,
kLogitsDim
);
TF_RETURN_IF_ERROR
(
fixed_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
TF_RETURN_IF_ERROR
(
linked_embedding_manager_
.
Reset
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
));
TF_RETURN_IF_ERROR
(
model_
.
Initialize
(
component_spec
,
kLogitsName
,
&
fixed_embedding_manager_
,
&
linked_embedding_manager_
,
&
network_state_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
FillLayer
(
kPreviousComponentName
,
kPreviousLayerName
,
kPreviousLayerValue
);
StartComponent
(
0
);
TF_RETURN_IF_ERROR
(
model_
.
Preprocess
(
&
session_state_
,
&
compute_session_
,
&
evaluate_state_
));
MutableMatrix
<
float
>
logits
=
GetLayer
(
kTestComponentName
,
kLogitsName
);
for
(
int
row
=
0
;
row
<
logits
.
num_rows
();
++
row
)
{
for
(
int
column
=
0
;
column
<
logits
.
num_columns
();
++
column
)
{
logits
.
row
(
row
)[
column
]
=
2.0
*
row
;
}
}
return
model_
.
Predict
(
network_states_
,
&
evaluate_state_
);
}
// Returns the sequence size passed to the |backend_|.
int
GetBackendSequenceSize
()
{
// The sequence size is not directly exposed, but can be inferred using one
// of the reverse step translators.
return
backend_
.
GetStepLookupFunction
(
"reverse-token"
)(
0
,
0
,
0
)
+
1
;
}
// Fixed and linked embedding managers.
FixedEmbeddingManager
fixed_embedding_manager_
;
LinkedEmbeddingManager
linked_embedding_manager_
;
// Input batch injected into Preprocess() by default.
InputBatchCache
input_
;
// Backend injected into Preprocess().
SequenceBackend
backend_
;
// Sequence-based model.
SequenceModel
model_
;
// Per-evaluation state.
SequenceModel
::
EvaluateState
evaluate_state_
;
};
// Returns a ComponentSpec that is supported.
ComponentSpec
MakeSupportedSpec
()
{
ComponentSpec
component_spec
;
component_spec
.
set_num_actions
(
kLogitsDim
);
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_extractors"
,
"EvenNumbers"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_linkers"
,
"LinkToPrevious"
});
component_spec
.
mutable_component_builder
()
->
mutable_parameters
()
->
insert
(
{
"sequence_predictor"
,
"CaptureLogits"
});
component_spec
.
mutable_backend
()
->
set_registered_name
(
"SequenceBackend"
);
FixedFeatureChannel
*
fixed_feature
=
component_spec
.
add_fixed_feature
();
fixed_feature
->
set_size
(
1
);
fixed_feature
->
set_embedding_dim
(
-
1
);
LinkedFeatureChannel
*
linked_feature
=
component_spec
.
add_linked_feature
();
linked_feature
->
set_source_component
(
kPreviousComponentName
);
linked_feature
->
set_source_layer
(
kPreviousLayerName
);
linked_feature
->
set_size
(
1
);
linked_feature
->
set_embedding_dim
(
-
1
);
return
component_spec
;
}
// Tests that the model supports a supported spec.
TEST_F
(
SequenceModelTest
,
Supported
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
EXPECT_TRUE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with the wrong backend.
TEST_F
(
SequenceModelTest
,
UnsupportedBackend
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with no features.
TEST_F
(
SequenceModelTest
,
UnsupportedNoFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with a multi-embedding fixed feature.
TEST_F
(
SequenceModelTest
,
UnsupportedMultiEmbeddingFixedFeature
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_fixed_feature
(
0
)
->
set_size
(
2
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with a multi-embedding linked feature.
TEST_F
(
SequenceModelTest
,
UnsupportedMultiEmbeddingLinkedFeature
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_size
(
2
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that the model rejects a spec with only recurrent links.
TEST_F
(
SequenceModelTest
,
UnsupportedOnlyRecurrentLinks
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_name
(
"foo"
);
component_spec
.
clear_fixed_feature
();
component_spec
.
mutable_linked_feature
(
0
)
->
set_source_component
(
"foo"
);
EXPECT_FALSE
(
SequenceModel
::
Supports
(
component_spec
));
}
// Tests that Initialize() succeeds on a supported spec.
TEST_F
(
SequenceModelTest
,
InitializeSupported
)
{
const
ComponentSpec
component_spec
=
MakeSupportedSpec
();
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_FALSE
(
model_
.
deterministic
());
EXPECT_TRUE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() detects deterministic components.
TEST_F
(
SequenceModelTest
,
InitializeDeterministic
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
1
);
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_TRUE
(
model_
.
deterministic
());
EXPECT_TRUE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() detects right-to-left components.
TEST_F
(
SequenceModelTest
,
InitializeLeftToRight
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_transition_system
()
->
mutable_parameters
()
->
insert
(
{
"left_to_right"
,
"false"
});
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_FALSE
(
model_
.
deterministic
());
EXPECT_FALSE
(
model_
.
left_to_right
());
EXPECT_EQ
(
model_
.
sequence_feature_manager
().
num_channels
(),
1
);
EXPECT_EQ
(
model_
.
sequence_link_manager
().
num_channels
(),
1
);
}
// Tests that Initialize() fails if the backend is wrong.
TEST_F
(
SequenceModelTest
,
WrongBackend
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
mutable_backend
()
->
set_registered_name
(
"bad"
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Invalid component backend"
));
}
// Tests that Initialize() fails if the number of actions in the ComponentSpec
// does not match the logits.
TEST_F
(
SequenceModelTest
,
WrongNumActions
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
set_num_actions
(
kLogitsDim
+
1
);
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Logits dimension mismatch"
));
}
// Tests that Initialize() fails if an unknown sequence extractor is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequenceExtractor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Extractor"
));
}
// Tests that Initialize() fails if an unknown sequence linker is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequenceLinker
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Linker"
));
}
// Tests that Initialize() fails if an unknown sequence predictor is specified.
TEST_F
(
SequenceModelTest
,
UnknownSequencePredictor
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_predictor"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Predictor"
));
}
// Tests that Initialize() fails on an unknown component builder parameter.
TEST_F
(
SequenceModelTest
,
UnknownComponentBuilderParameter
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"bad"
]
=
"bad"
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"Unknown attribute"
));
}
// Tests that Initialize() fails if there are no fixed or linked features.
TEST_F
(
SequenceModelTest
,
InitializeRequiresFeatures
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
component_spec
.
clear_linked_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
""
;
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
""
;
EXPECT_THAT
(
Run
(
component_spec
),
test
::
IsErrorWithSubstr
(
"No fixed or linked features"
));
}
// Tests that the model fails if a null batch is returned.
TEST_F
(
SequenceModelTest
,
NullBatch
)
{
EXPECT_CALL
(
compute_session_
,
GetInputBatchCache
()).
WillOnce
(
Return
(
nullptr
));
EXPECT_THAT
(
Run
(
MakeSupportedSpec
()),
test
::
IsErrorWithSubstr
(
"Null input batch"
));
}
// Tests that the model properly sets up the EvaluateState and logits.
TEST_F
(
SequenceModelTest
,
Success
)
{
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
0
),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
1
),
2
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
2
),
4
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
kNumSteps
);
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
evaluate_state_
.
links
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
0.0
);
EXPECT_TRUE
(
is_out_of_bounds
);
evaluate_state_
.
links
.
Get
(
0
,
1
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
kPreviousLayerValue
);
EXPECT_FALSE
(
is_out_of_bounds
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model works with only fixed features.
TEST_F
(
SequenceModelTest
,
FixedFeaturesOnly
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_linked_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_linkers"
]
=
""
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
0
),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
1
),
2
);
EXPECT_EQ
(
evaluate_state_
.
features
.
GetId
(
0
,
2
),
4
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
0
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model works with only linked features.
TEST_F
(
SequenceModelTest
,
LinkedFeaturesOnly
)
{
ComponentSpec
component_spec
=
MakeSupportedSpec
();
component_spec
.
clear_fixed_feature
();
(
*
component_spec
.
mutable_component_builder
()
->
mutable_parameters
())[
"sequence_extractors"
]
=
""
;
TF_ASSERT_OK
(
Run
(
component_spec
));
EXPECT_EQ
(
GetBackendSequenceSize
(),
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
num_steps
,
kNumSteps
);
EXPECT_EQ
(
evaluate_state_
.
input
,
&
input_
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_channels
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
features
.
num_steps
(),
0
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_channels
(),
1
);
EXPECT_EQ
(
evaluate_state_
.
links
.
num_steps
(),
kNumSteps
);
Vector
<
float
>
embedding
;
bool
is_out_of_bounds
=
false
;
evaluate_state_
.
links
.
Get
(
0
,
0
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
0.0
);
EXPECT_TRUE
(
is_out_of_bounds
);
evaluate_state_
.
links
.
Get
(
0
,
1
,
&
embedding
,
&
is_out_of_bounds
);
ExpectVector
(
embedding
,
kLinkedDim
,
kPreviousLayerValue
);
EXPECT_FALSE
(
is_out_of_bounds
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
ASSERT_EQ
(
logits
.
num_columns
(),
kLogitsDim
);
for
(
int
i
=
0
;
i
<
kNumSteps
;
++
i
)
{
ExpectVector
(
logits
.
row
(
i
),
kLogitsDim
,
2.0
*
i
);
}
}
// Tests that the model fails if the fixed and linked features disagree on the
// number of steps.
TEST_F
(
SequenceModelTest
,
FixedAndLinkedDisagree
)
{
EvenNumbers
::
SetNumSteps
(
5
);
LinkToPrevious
::
SetNumSteps
(
6
);
EXPECT_THAT
(
Run
(
MakeSupportedSpec
()),
test
::
IsErrorWithSubstr
(
"Sequence length mismatch between fixed "
"features (5) and linked features (6)"
));
}
// Tests that the model can handle an empty sequence.
TEST_F
(
SequenceModelTest
,
EmptySequence
)
{
EvenNumbers
::
SetNumSteps
(
0
);
LinkToPrevious
::
SetNumSteps
(
0
);
TF_ASSERT_OK
(
Run
(
MakeSupportedSpec
()));
EXPECT_EQ
(
GetBackendSequenceSize
(),
0
);
const
Matrix
<
float
>
logits
=
CaptureLogits
::
GetLogits
();
ASSERT_EQ
(
logits
.
num_rows
(),
0
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_predictor.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
SequencePredictor
::
Select
(
const
ComponentSpec
&
component_spec
,
string
*
name
)
{
string
supporting_name
;
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
Factory
*
factory_function
=
registrar
->
object
();
std
::
unique_ptr
<
SequencePredictor
>
current_predictor
(
factory_function
());
if
(
!
current_predictor
->
Supports
(
component_spec
))
continue
;
if
(
!
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
Internal
(
"Multiple SequencePredictors support ComponentSpec ("
,
supporting_name
,
" and "
,
registrar
->
name
(),
"): "
,
component_spec
.
ShortDebugString
());
}
supporting_name
=
registrar
->
name
();
}
if
(
supporting_name
.
empty
())
{
return
tensorflow
::
errors
::
NotFound
(
"No SequencePredictor supports ComponentSpec: "
,
component_spec
.
ShortDebugString
());
}
// Success; make modifications.
*
name
=
supporting_name
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
SequencePredictor
::
New
(
const
string
&
name
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequencePredictor
>
*
predictor
)
{
std
::
unique_ptr
<
SequencePredictor
>
matching_predictor
;
TF_RETURN_IF_ERROR
(
SequencePredictor
::
CreateOrError
(
name
,
&
matching_predictor
));
TF_RETURN_IF_ERROR
(
matching_predictor
->
Initialize
(
component_spec
));
// Success; make modifications.
*
predictor
=
std
::
move
(
matching_predictor
);
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Predictor"
,
dragnn
::
runtime
::
SequencePredictor
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/sequence_predictor.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for making predictions on sequences.
//
// This predictor can be used to avoid ComputeSession overhead in simple cases;
// for example, predicting sequences of POS tags.
class
SequencePredictor
:
public
RegisterableClass
<
SequencePredictor
>
{
public:
// Sets |predictor| to an instance of the subclass named |name| initialized
// from the |component_spec|. On error, returns non-OK and modifies nothing.
static
tensorflow
::
Status
New
(
const
string
&
name
,
const
ComponentSpec
&
component_spec
,
std
::
unique_ptr
<
SequencePredictor
>
*
predictor
);
SequencePredictor
(
const
SequencePredictor
&
)
=
delete
;
SequencePredictor
&
operator
=
(
const
SequencePredictor
&
)
=
delete
;
virtual
~
SequencePredictor
()
=
default
;
// Sets |name| to the registered name of the SequencePredictor that supports
// the |component_spec|. On error, returns non-OK and modifies nothing. The
// returned statuses include:
// * OK: If a supporting SequencePredictor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static
tensorflow
::
Status
Select
(
const
ComponentSpec
&
component_spec
,
string
*
name
);
// Makes a sequence of predictions using the per-step |logits| and writes
// annotations to the |input|.
virtual
tensorflow
::
Status
Predict
(
Matrix
<
float
>
logits
,
InputBatchCache
*
input
)
const
=
0
;
protected:
SequencePredictor
()
=
default
;
private:
// Helps prevent use of the Create() method; use New() instead.
using
RegisterableClass
<
SequencePredictor
>::
Create
;
// Returns true if this supports the |component_spec|. Implementations must
// coordinate to ensure that at most one supports any given |component_spec|.
virtual
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
=
0
;
// Initializes this from the |component_spec|. On error, returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Sequence Predictor"
,
dragnn
::
runtime
::
SequencePredictor
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequencePredictor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
research/syntaxnet/dragnn/runtime/sequence_predictor_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Supports components named "success" and initializes successfully.
class
Success
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"success"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Success
);
// Supports components named "failure" and fails to initialize.
class
Failure
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"failure"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
errors
::
Internal
(
"Boom!"
);
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Failure
);
// Supports components named "duplicate" and initializes successfully.
class
Duplicate
:
public
SequencePredictor
{
public:
// Implements SequencePredictor.
bool
Supports
(
const
ComponentSpec
&
component_spec
)
const
override
{
return
component_spec
.
name
()
==
"duplicate"
;
}
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Predict
(
Matrix
<
float
>
,
InputBatchCache
*
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Duplicate
);
// Duplicate of the above.
using
Duplicate2
=
Duplicate
;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR
(
Duplicate2
);
// Tests that a component can be successfully created.
TEST
(
SequencePredictorTest
,
Success
)
{
string
name
;
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"success"
);
TF_ASSERT_OK
(
SequencePredictor
::
Select
(
component_spec
,
&
name
));
ASSERT_EQ
(
name
,
"Success"
);
TF_EXPECT_OK
(
SequencePredictor
::
New
(
name
,
component_spec
,
&
predictor
));
EXPECT_NE
(
predictor
,
nullptr
);
}
// Tests that errors in Initialize() are reported.
TEST
(
SequencePredictorTest
,
FailToInitialize
)
{
string
name
;
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"failure"
);
TF_ASSERT_OK
(
SequencePredictor
::
Select
(
component_spec
,
&
name
));
EXPECT_EQ
(
name
,
"Failure"
);
EXPECT_THAT
(
SequencePredictor
::
New
(
name
,
component_spec
,
&
predictor
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
EXPECT_EQ
(
predictor
,
nullptr
);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST
(
SequencePredictorTest
,
UnsupportedSpec
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"unsupported"
);
EXPECT_THAT
(
SequencePredictor
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"No SequencePredictor supports ComponentSpec"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
// Tests that unsupported subclass names are reported as errors.
TEST
(
SequencePredictorTest
,
UnsupportedSubclass
)
{
std
::
unique_ptr
<
SequencePredictor
>
predictor
;
ComponentSpec
component_spec
;
EXPECT_THAT
(
SequencePredictor
::
New
(
"Unsupported"
,
component_spec
,
&
predictor
),
test
::
IsErrorWithSubstr
(
"Unknown DRAGNN Runtime Sequence Predictor"
));
EXPECT_EQ
(
predictor
,
nullptr
);
}
// Tests that multiple supporting predictors are reported as INTERNAL errors.
TEST
(
SequencePredictorTest
,
Duplicate
)
{
string
name
=
"not overwritten"
;
ComponentSpec
component_spec
;
component_spec
.
set_name
(
"duplicate"
);
EXPECT_THAT
(
SequencePredictor
::
Select
(
component_spec
,
&
name
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
INTERNAL
,
"Multiple SequencePredictors support ComponentSpec"
));
EXPECT_EQ
(
name
,
"not overwritten"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment