Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4bb31d0
Commit
a4bb31d0
authored
May 02, 2018
by
Terry Koo
Browse files
Export @195097388.
parent
dea7ecf6
Changes
296
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2630 additions
and
0 deletions
+2630
-0
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer_test.cc
...ragnn/runtime/clear_dropout_component_transformer_test.cc
+62
-0
research/syntaxnet/dragnn/runtime/component.cc
research/syntaxnet/dragnn/runtime/component.cc
+107
-0
research/syntaxnet/dragnn/runtime/component.h
research/syntaxnet/dragnn/runtime/component.h
+111
-0
research/syntaxnet/dragnn/runtime/component_test.cc
research/syntaxnet/dragnn/runtime/component_test.cc
+201
-0
research/syntaxnet/dragnn/runtime/component_transformation.cc
...arch/syntaxnet/dragnn/runtime/component_transformation.cc
+91
-0
research/syntaxnet/dragnn/runtime/component_transformation.h
research/syntaxnet/dragnn/runtime/component_transformation.h
+86
-0
research/syntaxnet/dragnn/runtime/component_transformation_test.cc
...syntaxnet/dragnn/runtime/component_transformation_test.cc
+241
-0
research/syntaxnet/dragnn/runtime/conversion.cc
research/syntaxnet/dragnn/runtime/conversion.cc
+82
-0
research/syntaxnet/dragnn/runtime/conversion.h
research/syntaxnet/dragnn/runtime/conversion.h
+58
-0
research/syntaxnet/dragnn/runtime/conversion_test.cc
research/syntaxnet/dragnn/runtime/conversion_test.cc
+140
-0
research/syntaxnet/dragnn/runtime/converter.cc
research/syntaxnet/dragnn/runtime/converter.cc
+145
-0
research/syntaxnet/dragnn/runtime/converter_test.sh
research/syntaxnet/dragnn/runtime/converter_test.sh
+92
-0
research/syntaxnet/dragnn/runtime/dynamic_component.cc
research/syntaxnet/dragnn/runtime/dynamic_component.cc
+173
-0
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
+193
-0
research/syntaxnet/dragnn/runtime/extensions.cc
research/syntaxnet/dragnn/runtime/extensions.cc
+81
-0
research/syntaxnet/dragnn/runtime/extensions.h
research/syntaxnet/dragnn/runtime/extensions.h
+233
-0
research/syntaxnet/dragnn/runtime/extensions_test.cc
research/syntaxnet/dragnn/runtime/extensions_test.cc
+266
-0
research/syntaxnet/dragnn/runtime/feed_forward_network.cc
research/syntaxnet/dragnn/runtime/feed_forward_network.cc
+90
-0
research/syntaxnet/dragnn/runtime/feed_forward_network_kernel.cc
...h/syntaxnet/dragnn/runtime/feed_forward_network_kernel.cc
+114
-0
research/syntaxnet/dragnn/runtime/feed_forward_network_kernel.h
...ch/syntaxnet/dragnn/runtime/feed_forward_network_kernel.h
+64
-0
No files found.
Too many changes to show.
To preserve performance only
296 of 296+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Verifies that a component spec without any dropout configuration passes
// through ComponentTransformer::ApplyAll() unchanged.
TEST(ClearDropoutComponentTransformerTest, DoesNotModifyIfNoDropout) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("foo");
  component_spec.add_fixed_feature()->set_name("words");

  // Snapshot the spec before transformation; it should be untouched.
  const ComponentSpec expected_spec = component_spec;
  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Verifies that dropout settings on a fixed feature channel are stripped by
// ComponentTransformer::ApplyAll(), leaving only the channel name.
TEST(ClearDropoutComponentTransformerTest, ClearsDropout) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("foo");

  // Configure a channel with dropout id and keep probabilities set.
  FixedFeatureChannel *channel = component_spec.add_fixed_feature();
  channel->set_name("words");
  channel->set_dropout_id(100);
  channel->add_dropout_keep_probability(1.0);
  channel->add_dropout_keep_probability(0.5);
  channel->add_dropout_keep_probability(0.1);

  // The expected result keeps the channel but drops all dropout fields.
  ComponentSpec expected_spec = component_spec;
  expected_spec.clear_fixed_feature();
  expected_spec.add_fixed_feature()->set_name("words");

  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Returns the component builder name from the |component_spec|, normalized by
// discarding any Python module path prefix (e.g., "a.b.FooBuilder" => "Foo")
// and stripping a trailing "Builder" suffix, if present.
string GetNormalizedComponentBuilderName(const ComponentSpec &component_spec) {
  // Registered names are (relative) Python module paths such as
  // "some.module.FooComponent"; only the final segment names the subclass.
  const std::vector<string> segments = tensorflow::str_util::Split(
      component_spec.component_builder().registered_name(), ".");
  CHECK_GT(segments.size(), 0) << "No builder name for component spec: "
                               << component_spec.ShortDebugString();
  tensorflow::StringPiece subclass_name = segments.back();

  // A Python ComponentBuilder builds a TF graph, whereas a runtime Component
  // executes the computation directly, so the "Builder" suffix is dropped.
  tensorflow::str_util::ConsumeSuffix(&subclass_name, "Builder");
  return subclass_name.ToString();
}
// Scans every registered Component subclass, instantiates it, and asks whether
// it supports the |spec|. Among supporting candidates, PreferredTo() decides
// precedence; the two candidates must agree on their relative ordering or a
// FAILED_PRECONDITION error is returned. On success, writes the winning
// registered name to |result|; returns NOT_FOUND if nothing matches.
tensorflow::Status Component::Select(const ComponentSpec &spec,
                                     string *result) {
  const string normalized_builder_name =
      GetNormalizedComponentBuilderName(spec);

  std::unique_ptr<Component> current_best;
  string current_best_name;
  for (const Registry::Registrar *component = registry()->components;
       component != nullptr; component = component->next()) {
    // component->object() is a function pointer to the subclass' constructor.
    std::unique_ptr<Component> next(component->object()());
    string next_name(component->name());
    if (!next->Supports(spec, normalized_builder_name)) continue;

    // First supported candidate becomes the provisional winner.
    if (current_best == nullptr) {
      current_best = std::move(next);
      current_best_name = next_name;
      continue;
    }

    // The two candidates must agree on which takes precedence.
    if (next->PreferredTo(*current_best)) {
      if (current_best->PreferredTo(*next)) {
        return tensorflow::errors::FailedPrecondition(
            "Classes '", current_best_name, "' and '", next_name,
            "' both think they should be preferred to each-other. Please "
            "add logic to their PreferredTo() methods to avoid this.");
      }
      current_best = std::move(next);
      current_best_name = next_name;
    } else if (!current_best->PreferredTo(*next)) {
      return tensorflow::errors::FailedPrecondition(
          "Classes '", current_best_name, "' and '", next_name,
          "' both think they should be dis-preferred to each-other. Please "
          "add logic to their PreferredTo() methods to avoid this.");
    }
  }

  if (current_best == nullptr) {
    return tensorflow::errors::NotFound(
        "Could not find a best spec for component '", spec.name(),
        "' with normalized builder name '", normalized_builder_name, "'");
  }
  *result = std::move(current_best_name);
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component"
,
dragnn
::
runtime
::
Component
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_COMPONENT_H_
#define DRAGNN_RUNTIME_COMPONENT_H_
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Helper method, currently only used by myelination.cc.
string
GetNormalizedComponentBuilderName
(
const
ComponentSpec
&
component_spec
);
// Interface for runtime components. Subclasses register themselves via
// DRAGNN_RUNTIME_REGISTER_COMPONENT and are selected per-spec by Select().
class Component : public RegisterableClass<Component> {
 public:
  // Not copyable; components are owned via unique_ptr.
  Component(const Component &that) = delete;
  Component &operator=(const Component &that) = delete;
  virtual ~Component() = default;

  // Initializes this to the configuration in the |component_spec|. Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component. Requests SessionState extensions
  // from the |extension_manager|. On error, returns non-OK.
  virtual tensorflow::Status Initialize(
      const ComponentSpec &component_spec, VariableStore *variable_store,
      NetworkStateManager *network_state_manager,
      ExtensionManager *extension_manager) = 0;

  // Evaluates this on the |session_state| and |compute_session|, which must
  // both be positioned at the current component. If |component_trace| is
  // non-null, overwrites it with extracted traces. On error, returns non-OK.
  virtual tensorflow::Status Evaluate(SessionState *session_state,
                                      ComputeSession *compute_session,
                                      ComponentTrace *component_trace) const = 0;

  // Returns the best component for a spec, searching through all registered
  // subclasses. This allows specialized implementations to be used.
  //
  // Sets |result| on success, otherwise returns an error message if a single
  // best matching component could not be found. Returned statuses include:
  //   * OK: If a single best matching component was found.
  //   * FAILED_PRECONDITION: If an error occurred during the search.
  //   * NOT_FOUND: If the search was error-free, but no matches were found.
  static tensorflow::Status Select(const ComponentSpec &spec, string *result);

 protected:
  Component() = default;

  // Whether this component supports a given spec. |spec| is the full component
  // spec and |normalized_builder_name| is the component builder name, with
  // Python modules and the suffix "Builder" stripped.
  virtual bool Supports(const ComponentSpec &spec,
                        const string &normalized_builder_name) const = 0;

  // Whether to prefer this component to another. (Both components must say
  // that they support the spec.)
  //
  // Components must agree on whether they are more or less specialized than
  // another component; disagreement makes Select() fail. Currently generic
  // implementations answer "never preferred" and optimized ones "always
  // preferred".
  virtual bool PreferredTo(const Component &other) const = 0;

 private:
  // Helps prevent use of the Create() method; use CreateOrError() instead.
  using RegisterableClass<Component>::Create;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component"
,
dragnn
::
runtime
::
Component
);
}
// namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_COMPONENT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(::syntaxnet::dragnn::runtime::Component, \
#subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_H_
research/syntaxnet/dragnn/runtime/component_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Asserts that |lhs| and |rhs| point at the same address.
void ExpectSameAddress(const void *lhs, const void *rhs) {
  EXPECT_EQ(lhs, rhs);
}
// A trivial Component used as the baseline implementation in these tests:
// Initialize()/Evaluate() are no-ops and it never claims preference.
class FooComponent : public Component {
 public:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "FooComponent";
  }
  bool PreferredTo(const Component &other) const override { return false; }
};

DRAGNN_RUNTIME_REGISTER_COMPONENT(FooComponent);
// Components that unconditionally claim preference; registering two of them
// lets tests exercise the "both preferred" conflict in Component::Select().
class ImTheBest1 : public FooComponent {
 public:
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "ImTheBest";
  }
  bool PreferredTo(const Component &other) const override { return true; }
};

class ImTheBest2 : public ImTheBest1 {};

DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest1);
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest2);
// Components that unconditionally decline preference; registering two of them
// lets tests exercise the "neither preferred" conflict in Component::Select().
class ImTheWorst1 : public FooComponent {
 public:
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "ImTheWorst";
  }
  bool PreferredTo(const Component &other) const override { return false; }
};

class ImTheWorst2 : public ImTheWorst1 {};

DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst1);
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst2);
// Specialized foo implementation that only supports specs with exactly one
// action, and claims preference over the generic FooComponent when it does.
class SpecializedFooComponent : public Component {
 public:
  // Implements Component; both lifecycle methods are no-ops.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "FooComponent" &&
           spec.num_actions() == 1;
  }
  bool PreferredTo(const Component &other) const override { return true; }
};

DRAGNN_RUNTIME_REGISTER_COMPONENT(SpecializedFooComponent);
// An empty spec has no builder name, so name normalization CHECK-fails.
TEST(ComponentTest, NameResolutionError) {
  ComponentSpec component_spec;
  EXPECT_DEATH(GetNormalizedComponentBuilderName(component_spec),
               "No builder name for component spec");
}
// Every Python-esque module-path variant of the builder name should normalize
// to "FooComponent" and resolve to the registered FooComponent.
TEST(ComponentTest, VariantsOfComponentBuilderNameResolve) {
  for (const string &registered_name :
       {"FooComponent", "FooComponentBuilder", "module.FooComponent",
        "module.FooComponentBuilder", "some.long.path.to.module.FooComponent",
        "some.long.path.to.module.FooComponentBuilder"}) {
    ComponentSpec component_spec;
    component_spec.mutable_component_builder()->set_registered_name(
        registered_name);

    string result;
    TF_ASSERT_OK(Component::Select(component_spec, &result));
    EXPECT_EQ(result, "FooComponent");
  }
}
// Two candidates that each claim preference over the other should yield a
// FAILED_PRECONDITION from Component::Select().
TEST(ComponentTest, ErrorWithBothPreferred) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("ImTheBest");

  string result;
  EXPECT_THAT(Component::Select(component_spec, &result),
              test::IsErrorWithCodeAndSubstr(
                  tensorflow::error::FAILED_PRECONDITION,
                  "Classes 'ImTheBest2' and 'ImTheBest1' "
                  "both think they should be preferred to "
                  "each-other. Please add logic to their "
                  "PreferredTo() methods to avoid this."));
}
// Two candidates that each defer to the other should also yield a
// FAILED_PRECONDITION from Component::Select().
TEST(ComponentTest, ErrorWithNeitherPreferred) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("ImTheWorst");

  string result;
  EXPECT_THAT(Component::Select(component_spec, &result),
              test::IsErrorWithCodeAndSubstr(
                  tensorflow::error::FAILED_PRECONDITION,
                  "Classes 'ImTheWorst2' and 'ImTheWorst1' both think they "
                  "should be dis-preferred to each-other. Please add logic to "
                  "their PreferredTo() methods to avoid this."));
}
// With num_actions != 1 the specialized implementation does not apply, so the
// generic FooComponent wins.
TEST(ComponentTest, DefaultComponent) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name(
      "FooComponent");
  component_spec.set_num_actions(45);

  string result;
  TF_EXPECT_OK(Component::Select(component_spec, &result));
  EXPECT_EQ(result, "FooComponent");
}
// With num_actions == 1 both implementations match and the specialized one,
// which claims preference, is selected.
TEST(ComponentTest, SpecializedComponent) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name(
      "FooComponent");
  component_spec.set_num_actions(1);

  string result;
  TF_EXPECT_OK(Component::Select(component_spec, &result));
  EXPECT_EQ(result, "SpecializedFooComponent");
}
// Tests that Select() returns NOT_FOUND when there is no matching component.
TEST(ComponentTest, NoMatchingComponentNotFound) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("unknown");

  string result;
  EXPECT_THAT(Component::Select(component_spec, &result),
              test::IsErrorWithCodeAndSubstr(
                  tensorflow::error::NOT_FOUND,
                  "Could not find a best spec for component"));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component_transformation.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/component.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Reads a text-format MasterSpec from |input_master_spec_path|, applies every
// registered ComponentTransformer to each component, and writes the result to
// |output_master_spec_path|. Returns non-OK on read, transform, or write
// failure.
tensorflow::Status TransformComponents(const string &input_master_spec_path,
                                       const string &output_master_spec_path) {
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(
      tensorflow::Env::Default(), input_master_spec_path, &master_spec));

  for (ComponentSpec &component_spec : *master_spec.mutable_component()) {
    TF_RETURN_IF_ERROR(ComponentTransformer::ApplyAll(&component_spec));
  }

  return tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                    output_master_spec_path, master_spec);
}
// Runs every registered transformer over the |component_spec| repeatedly, in
// ascending order of registered name, until the spec reaches a fixed point.
// The input is only overwritten after convergence, so nothing is modified on
// error. Returns INTERNAL if the transformers fail to converge.
tensorflow::Status ComponentTransformer::ApplyAll(
    ComponentSpec *component_spec) {
  // Limit on the number of iterations, to prevent infinite loops.
  static constexpr int kMaxNumIterations = 1000;

  // Collect registered names in a std::set, which sorts for determinism.
  std::set<string> names;
  for (const Registry::Registrar *registrar = registry()->components;
       registrar != nullptr; registrar = registrar->next()) {
    names.insert(registrar->name());
  }

  // Instantiate one transformer per registered name, in sorted order.
  std::vector<std::unique_ptr<ComponentTransformer>> transformers;
  transformers.reserve(names.size());
  for (const string &name : names) transformers.emplace_back(Create(name));

  // Work on a local copy so the caller's spec is untouched on error.
  ComponentSpec local_spec = *component_spec;
  for (int iteration = 0; iteration < kMaxNumIterations; ++iteration) {
    const ComponentSpec original_spec = local_spec;
    for (const auto &transformer : transformers) {
      // Recompute the component type each time; a transformer may change it.
      const string component_type =
          GetNormalizedComponentBuilderName(local_spec);
      TF_RETURN_IF_ERROR(transformer->Transform(component_type, &local_spec));
    }

    if (tensorflow::protobuf::util::MessageDifferencer::Equals(
            local_spec, original_spec)) {
      // Converged successfully; commit the modifications.
      *component_spec = local_spec;
      return tensorflow::Status::OK();
    }
  }

  return tensorflow::errors::Internal("Failed to converge within ",
                                      kMaxNumIterations,
                                      " ComponentTransformer iterations");
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component Transformer"
,
dragnn
::
runtime
::
ComponentTransformer
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component_transformation.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for transforming ComponentSpecs, typically (but not necessarily) in
// ways that are intended to improve speed. For example, a transformer might
// detect a favorable component configuration and replace a generic Component
// implementation with a faster version.
#ifndef DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#define DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Loads a MasterSpec from the |input_master_spec_path|, applies all registered
// ComponentTransformers to it (see ComponentTransformer::ApplyAll() below), and
// writes it to the |output_master_spec_path|. On error, returns non-OK.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow
::
Status
TransformComponents
(
const
string
&
input_master_spec_path
,
const
string
&
output_master_spec_path
);
// Interface for modules that can transform a ComponentSpec, which allows
// transformations to be plugged in on a decentralized basis.
class ComponentTransformer : public RegisterableClass<ComponentTransformer> {
 public:
  // Not copyable; transformers are owned via unique_ptr.
  ComponentTransformer(const ComponentTransformer &that) = delete;
  ComponentTransformer &operator=(const ComponentTransformer &that) = delete;
  virtual ~ComponentTransformer() = default;

  // Repeatedly loops through all registered transformers and applies them to
  // the |component_spec| until no more changes occur. For determinism, each
  // loop applies the transformers in ascending order of registered name. On
  // error, returns non-OK and modifies nothing.
  static tensorflow::Status ApplyAll(ComponentSpec *component_spec);

 protected:
  ComponentTransformer() = default;

 private:
  // Helps prevent use of the Create() method.
  using RegisterableClass<ComponentTransformer>::Create;

  // Modifies the |component_spec|, which is currently configured to use the
  // |component_type|, if compatible. On error, returns non-OK and modifies
  // nothing. Note that it is not an error if the |component_spec| is simply
  // not compatible with the desired transformation.
  virtual tensorflow::Status Transform(const string &component_type,
                                       ComponentSpec *component_spec) = 0;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component Transformer"
,
dragnn
::
runtime
::
ComponentTransformer
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::ComponentTransformer, #subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
research/syntaxnet/dragnn/runtime/component_transformation_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Transformer that errors out whenever the component type is "fail", and is
// otherwise a no-op.
class MaybeFail : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.
  tensorflow::Status Transform(const string &component_type,
                               ComponentSpec *) override {
    if (component_type == "fail") {
      return tensorflow::errors::InvalidArgument("Boom!");
    }
    return tensorflow::Status::OK();
  }
};

DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(MaybeFail);
// Base class for transformers that rename a component: whenever the spec's
// name equals |from|, it is replaced with |to|.
class ChangeNameBase : public ComponentTransformer {
 public:
  // Creates a transformer that changes the component name from |from| to |to|.
  explicit ChangeNameBase(const string &from, const string &to)
      : from_(from), to_(to) {}

  // Implements ComponentTransformer.
  tensorflow::Status Transform(const string &,
                               ComponentSpec *component_spec) override {
    if (component_spec->name() == from_) component_spec->set_name(to_);
    return tensorflow::Status::OK();
  }

 private:
  // Component name to look for.
  const string from_;

  // Component name to change to.
  const string to_;
};
// Renamers forming a chain: chain1 => chain2 => chain3. Their composition
// exercises the fixed-point iteration in ApplyAll().
class Chain1To2 : public ChangeNameBase {
 public:
  Chain1To2() : ChangeNameBase("chain1", "chain2") {}
};

class Chain2To3 : public ChangeNameBase {
 public:
  Chain2To3() : ChangeNameBase("chain2", "chain3") {}
};
// Appends "." to the component name whenever it starts with "cycle", so the
// spec never converges — used to trigger the iteration limit in ApplyAll().
class Cycle : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.
  tensorflow::Status Transform(const string &,
                               ComponentSpec *component_spec) override {
    if (component_spec->name().substr(0, 5) == "cycle") {
      component_spec->mutable_name()->append(".");
    }
    return tensorflow::Status::OK();
  }
};

// Intentionally registered out of order to exercise sorting on registered name.
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Chain1To2);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Chain2To3);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Cycle);
// Arbitrary bogus path.
constexpr
char
kInvalidPath
[]
=
"path/to/some/invalid/file"
;
// Returns a unique temporary directory for tests.  Each call creates a fresh
// numbered subdirectory ("tmp_0", "tmp_1", ...) under the test tmp dir.
string GetUniqueTemporaryDir() {
  static int counter = 0;
  const string dir = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(),
      tensorflow::strings::StrCat("tmp_", counter++));
  TF_CHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(dir));
  return dir;
}
// Returns a MasterSpec parsed from the |text|.  Check-fails on parse errors,
// since the inputs are test-local literals.
MasterSpec ParseSpec(const string &text) {
  MasterSpec spec;
  CHECK(TextFormat::ParseFromString(text, &spec));
  return spec;
}
// Tests that TransformComponents() fails if the input master spec path is
// invalid.
TEST(TransformComponentsTest, InvalidInputMasterSpecPath) {
  // Only the output path needs to be real; the bogus input should fail first.
  const string output_path =
      tensorflow::io::JoinPath(GetUniqueTemporaryDir(), "output");
  EXPECT_FALSE(TransformComponents(kInvalidPath, output_path).ok());
}
// Tests that TransformComponents() fails if the output master spec path is
// invalid.
TEST(TransformComponentsTest, InvalidOutputMasterSpecPath) {
  // Write a valid (empty) input spec so only the output path is at fault.
  const string input_path =
      tensorflow::io::JoinPath(GetUniqueTemporaryDir(), "input");
  const MasterSpec empty_spec;
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, empty_spec));
  EXPECT_FALSE(TransformComponents(input_path, kInvalidPath).ok());
}
// Tests that TransformComponents() fails if one of the ComponentTransformers
// fails.
TEST(TransformComponentsTest, FailingComponentTransformer) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string input_path = tensorflow::io::JoinPath(temp_dir, "input");
  const string output_path = tensorflow::io::JoinPath(temp_dir, "output");
  // The 'fail' builder name triggers the MaybeFail transformer, whose error
  // ("Boom!") should propagate out of TransformComponents().
  const MasterSpec input_spec = ParseSpec(R"(
component {
component_builder { registered_name:'foo' }
}
component {
component_builder { registered_name:'fail' }
}
)");
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, input_spec));
  EXPECT_THAT(TransformComponents(input_path, output_path),
              test::IsErrorWithSubstr("Boom!"));
}
// Tests that TransformComponents() properly applies all transformations.
TEST(TransformComponentsTest, Success) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string input_path = tensorflow::io::JoinPath(temp_dir, "input");
  const string output_path = tensorflow::io::JoinPath(temp_dir, "output");
  // 'chain1' should be renamed to 'chain3' via the Chain1To2 and Chain2To3
  // transformers; 'irrelevant' should pass through unchanged.
  const MasterSpec input_spec = ParseSpec(R"(
component {
name:'chain1'
component_builder { registered_name:'foo' }
}
component {
name:'irrelevant'
component_builder { registered_name:'bar' }
}
)");
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, input_spec));
  TF_ASSERT_OK(TransformComponents(input_path, output_path));

  // Read back the transformed spec and compare against the expectation.
  MasterSpec actual_spec;
  TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                         output_path, &actual_spec));
  const MasterSpec expected_spec = ParseSpec(R"(
component {
name:'chain3'
component_builder { registered_name:'foo' }
}
component {
name:'irrelevant'
component_builder { registered_name:'bar' }
}
)");
  EXPECT_THAT(actual_spec, test::EqualsProto(expected_spec));
}
// Tests that ComponentTransformer::ApplyAll() makes the expected modifications,
// including chained modifications (chain1 => chain2 => chain3).
TEST(ComponentTransformerTest, ApplyAllSuccess) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("foo");
  component_spec.set_name("chain1");

  // The expected result is the same spec with the fully-chained name.
  ComponentSpec expected_spec = component_spec;
  expected_spec.set_name("chain3");

  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that ComponentTransformer::ApplyAll() limits the number of iterations.
TEST(ComponentTransformerTest, ApplyAllLimitIterations) {
  // The "cycle" name triggers the Cycle transformer, which modifies the spec
  // on every pass and therefore never converges.
  ComponentSpec spec;
  spec.mutable_component_builder()->set_registered_name("foo");
  spec.set_name("cycle");
  EXPECT_THAT(ComponentTransformer::ApplyAll(&spec),
              test::IsErrorWithSubstr("Failed to converge"));
}
// Tests that ComponentTransformer::ApplyAll() fails if one of the
// ComponentTransformers fails.
TEST(ComponentTransformerTest, ApplyAllFailure) {
  // The 'fail' builder name makes the MaybeFail transformer return "Boom!".
  ComponentSpec spec;
  spec.mutable_component_builder()->set_registered_name("fail");
  EXPECT_THAT(ComponentTransformer::ApplyAll(&spec),
              test::IsErrorWithSubstr("Boom!"));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/conversion.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <memory>
#include <utility>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/array_variable_store_builder.h"
#include "dragnn/runtime/master.h"
#include "dragnn/runtime/trained_model_variable_store.h"
#include "dragnn/runtime/variable_store.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow::Status ConvertVariables(const string &saved_model_dir,
                                    const string &master_spec_path,
                                    const string &variables_spec_path,
                                    const string &variables_data_path) {
  // Read the trained model.  |trained_model_store| is a non-owning alias of
  // |store|, retained so Reset() can be called on the concrete type after
  // ownership moves into the unique_ptr.
  auto *trained_model_store = new TrainedModelVariableStore();
  std::unique_ptr<VariableStore> store(trained_model_store);
  TF_RETURN_IF_ERROR(trained_model_store->Reset(saved_model_dir));

  // Wrap the TF store to enable averaging and capturing.
  //
  // The averaging wrapper currently needs to allow fall-back versions, since
  // derived parameters (used for the LSTM network) read averaged versions via
  // their TensorFlow-internal logic.
  //
  // The capturing wrapper must be the outermost, so variable names, formats,
  // and content are captured exactly as the components would receive them.
  //
  // NOTE: each reset is safe because the new-expression argument (which moves
  // |store|) is fully evaluated before reset() runs.
  store.reset(new TryAveragedVariableStoreWrapper(std::move(store), true));
  store.reset(new FlexibleMatrixVariableStoreWrapper(std::move(store)));
  auto *capturing_store = new CaptureUsedVariableStoreWrapper(std::move(store));
  store.reset(capturing_store);

  // Initialize a master using the wrapped store.  This should populate the
  // |capturing_store| with all of the used variables.
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                               master_spec_path, &master_spec));
  Master master;
  TF_RETURN_IF_ERROR(master.Initialize(master_spec, std::move(store)));
  // |capturing_store| is now owned (transitively) by |master|; it is only
  // dereferenced below while |master| is alive.

  // Convert the used variables into an ArrayVariableStore.
  ArrayVariableStoreSpec variables_spec;
  string variables_data;
  TF_RETURN_IF_ERROR(ArrayVariableStoreBuilder::Build(
      capturing_store->variables(), &variables_spec, &variables_data));

  // Write the converted variables: a text-format spec plus a raw data blob.
  TF_RETURN_IF_ERROR(tensorflow::WriteTextProto(
      tensorflow::Env::Default(), variables_spec_path, variables_spec));
  TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
      tensorflow::Env::Default(), variables_data_path, variables_data));
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/conversion.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for converting pre-trained models into a production-ready format.
#ifndef DRAGNN_RUNTIME_CONVERSION_H_
#define DRAGNN_RUNTIME_CONVERSION_H_
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Converts selected variables from a pre-trained TF model into the format used
// by the ArrayVariableStore. Only converts the variables required to run the
// components in a given MasterSpec.
//
// Inputs:
// saved_model_dir: TF SavedModel directory.
// master_spec_path: Text-format MasterSpec proto.
//
// Outputs:
// variables_spec_path: Text-format ArrayVariableStoreSpec proto.
// variables_data_path: Byte array representing an ArrayVariableStore.
//
// This function will instantiate and initialize a Master using the MasterSpec
// at the |master_path|, so any registered components used by that MasterSpec
// must be linked into the binary.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow
::
Status
ConvertVariables
(
const
string
&
saved_model_dir
,
const
string
&
master_spec_path
,
const
string
&
variables_spec_path
,
const
string
&
variables_data_path
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_CONVERSION_H_
research/syntaxnet/dragnn/runtime/conversion_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Fixture holding the input, output, and golden-file paths used by the
// ConvertVariables() tests below.
class ConvertVariablesTest : public ::testing::Test {
 protected:
  // The input files.
  const string kSavedModelDir = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/rnn_tagger");
  const string kMasterSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec");

  // Writable output files.
  const string kVariablesSpecPath =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "variables_spec");
  const string kVariablesDataPath =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "variables_data");

  // Bogus file for tests.
  const string kInvalidPath = "path/to/some/invalid/file";

  // Expected output files.
  const string kExpectedVariablesSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/conversion_output_variables_spec");
  const string kExpectedVariablesDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/conversion_output_variables_data");

  // Local relative paths to the output files, used by RegressionTest when
  // regenerating the golden testdata.
  const string kLocalVariablesSpecPath =
      "dragnn/runtime/testdata/"
      "conversion_output_variables_spec";
  const string kLocalVariablesDataPath =
      "dragnn/runtime/testdata/"
      "conversion_output_variables_data";
};
// Tests that the conversion fails if the saved model is invalid.
TEST_F(ConvertVariablesTest, InvalidSavedModel) {
  EXPECT_FALSE(ConvertVariables(kInvalidPath, kMasterSpecPath,
                                kVariablesSpecPath, kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the master spec is invalid.
TEST_F(ConvertVariablesTest, InvalidMasterSpec) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kInvalidPath,
                                kVariablesSpecPath, kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the variables spec path is unwritable.
TEST_F(ConvertVariablesTest, InvalidVariablesSpec) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kMasterSpecPath, kInvalidPath,
                                kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the variables data path is unwritable.
TEST_F(ConvertVariablesTest, InvalidVariablesData) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kMasterSpecPath,
                                kVariablesSpecPath, kInvalidPath)
                   .ok());
}
// Tests that the conversion succeeds on the pre-trained inputs and reproduces
// expected outputs.
TEST_F(ConvertVariablesTest, RegressionTest) {
  TF_EXPECT_OK(ConvertVariables(kSavedModelDir, kMasterSpecPath,
                                kVariablesSpecPath, kVariablesDataPath));

  // Read back what the conversion produced.
  ArrayVariableStoreSpec actual_variables_spec;
  string actual_variables_data;
  TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                         kVariablesSpecPath,
                                         &actual_variables_spec));
  TF_ASSERT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                            kVariablesDataPath,
                                            &actual_variables_data));

  // Deliberately-dead branch: flip to `true` locally to regenerate the golden
  // testdata files instead of comparing against them.  Keep `false` when
  // checking in.
  if (false) {
    TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                            kLocalVariablesSpecPath,
                                            actual_variables_spec));
    TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                               kLocalVariablesDataPath,
                                               actual_variables_data));
  } else {
    // Compare against the checked-in golden spec and data.
    ArrayVariableStoreSpec expected_variables_spec;
    string expected_variables_data;
    TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                           kExpectedVariablesSpecPath,
                                           &expected_variables_spec));
    TF_ASSERT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                              kExpectedVariablesDataPath,
                                              &expected_variables_data));
    EXPECT_THAT(actual_variables_spec,
                test::EqualsProto(expected_variables_spec));
    EXPECT_EQ(actual_variables_data, expected_variables_data);
  }
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/converter.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Tool for converting trained models for use in the runtime.
#include <set>
#include <string>
#include <vector>
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/conversion.h"
#include "dragnn/runtime/myelin/myelination.h"
#include "dragnn/runtime/xla/xla_compilation.h"
#include "syntaxnet/base.h"
#include "sling/base/flags.h" // TF does not support flags, but SLING does
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
// Command-line flags for the converter tool.
DEFINE_string(saved_model_dir, "", "Path to TF SavedModel directory.");
DEFINE_string(master_spec_file, "", "Path to text-format MasterSpec proto.");
// NOTE: trailing period added for consistency with the other flag help texts.
DEFINE_string(
    myelin_components, "",
    "Comma-delimited list of components to compile using Myelin, if any.");
DEFINE_string(
    xla_components, "",
    "Comma-delimited list of components to compile using XLA, if any.");
DEFINE_string(xla_model_name, "", "Name to apply to XLA-based components.");
DEFINE_string(
    output_dir, "",
    "Path to an output directory. This will be filled with the following "
    "files and subdirectories. MasterSpec: Converted text-format MasterSpec "
    "proto. ArrayVariableStoreSpec: Converted text-format variable spec. "
    "ArrayVariableStoreData: Converted variable data. myelin/*: Compiled "
    "Myelin components, if any. xla/*: Compiled XLA components, if any.");
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Splits the |list| on commas and returns the set of elements.
std
::
set
<
string
>
Split
(
const
string
&
list
)
{
const
std
::
vector
<
string
>
elements
=
tensorflow
::
str_util
::
Split
(
list
,
","
);
return
std
::
set
<
string
>
(
elements
.
begin
(),
elements
.
end
());
}
// Creates an empty directory at the |path|. If the directory exists, it is
// recursively deleted first.  Check-fails on any filesystem error.
void CreateEmptyDir(const string &path) {
  // Ensure that the directory exists; otherwise DeleteRecursively() may fail.
  TF_QCHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(path));
  int64 unused_undeleted_files, unused_undeleted_dirs;
  TF_QCHECK_OK(tensorflow::Env::Default()->DeleteRecursively(
      path, &unused_undeleted_files, &unused_undeleted_dirs));
  // Recreate the (now empty) directory.
  TF_QCHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(path));
}
// Performs Myelin compilation on the MasterSpec at |master_spec_path|, if
// requested via --myelin_components. Returns the path to the converted or
// original MasterSpec.
string CompileMyelin(const string &master_spec_path) {
  const std::set<string> to_compile = Split(FLAGS_myelin_components);
  if (to_compile.empty()) return master_spec_path;  // nothing requested

  LOG(INFO) << "Compiling Myelin in MasterSpec " << master_spec_path;
  const string myelin_dir =
      tensorflow::io::JoinPath(FLAGS_output_dir, "myelin");
  CreateEmptyDir(myelin_dir);
  TF_QCHECK_OK(MyelinateCells(FLAGS_saved_model_dir, master_spec_path,
                              to_compile, myelin_dir));
  return tensorflow::io::JoinPath(myelin_dir, "master-spec");
}
// Performs XLA compilation on the MasterSpec at |master_spec_path|, if
// requested via --xla_components. Returns the path to the converted or
// original MasterSpec.
string CompileXla(const string &master_spec_path) {
  const std::set<string> to_compile = Split(FLAGS_xla_components);
  if (to_compile.empty()) return master_spec_path;  // nothing requested

  LOG(INFO) << "Compiling XLA in MasterSpec " << master_spec_path;
  const string xla_dir = tensorflow::io::JoinPath(FLAGS_output_dir, "xla");
  CreateEmptyDir(xla_dir);
  TF_QCHECK_OK(XlaCompileCells(FLAGS_saved_model_dir, master_spec_path,
                               to_compile, FLAGS_xla_model_name, xla_dir));
  return tensorflow::io::JoinPath(xla_dir, "master-spec");
}
// Transforms the MasterSpec at |master_spec_path| via the registered
// ComponentTransformers, and returns the path to the transformed MasterSpec.
string Transform(const string &master_spec_path) {
  LOG(INFO) << "Transforming MasterSpec " << master_spec_path;
  const string transformed_path =
      tensorflow::io::JoinPath(FLAGS_output_dir, "MasterSpec");
  TF_QCHECK_OK(TransformComponents(master_spec_path, transformed_path));
  return transformed_path;
}
// Performs final variable conversion on the MasterSpec at |master_spec_path|.
void
Convert
(
const
string
&
master_spec_path
)
{
LOG
(
INFO
)
<<
"Converting MasterSpec "
<<
master_spec_path
;
const
string
variables_data_path
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"ArrayVariableStoreData"
);
const
string
variables_spec_path
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"ArrayVariableStoreSpec"
);
TF_QCHECK_OK
(
ConvertVariables
(
FLAGS_saved_model_dir
,
master_spec_path
,
variables_spec_path
,
variables_data_path
));
}
// Implements main().
void
Main
()
{
CreateEmptyDir
(
FLAGS_output_dir
);
string
master_spec_path
=
FLAGS_master_spec_file
;
master_spec_path
=
CompileMyelin
(
master_spec_path
);
master_spec_path
=
CompileXla
(
master_spec_path
);
master_spec_path
=
Transform
(
master_spec_path
);
Convert
(
master_spec_path
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
// Entry point: parses flags (via SLING, since TF has no flag library) and
// delegates to the runtime's Main().
int main(int argc, char **argv) {
  sling::Flag::ParseCommandLineFlags(&argc, argv, true);
  syntaxnet::dragnn::runtime::Main();
  return 0;
}
research/syntaxnet/dragnn/runtime/converter_test.sh
0 → 100755
View file @
a4bb31d0
#!/bin/bash
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Test for converter tool. To update the testdata, run the test with a single
# command-line argument specifying the path to the testdata directory.
set
-e
set
-u
# Infer the location of the data dependencies.
if
[[
-d
"
${
BASH_SOURCE
[0]
}
.runfiles"
]]
;
then
# Use the ".runfiles" directory if available (this typically happens when
# running manually). SyntaxNet does not specify a workspace name, so the
# runfiles are placed in ".runfiles/__main__". If SyntaxNet is configured
# with a workspace name, then change "__main__" to that name. See
# https://github.com/bazelbuild/bazel/wiki/Updating-the-runfiles-tree-structure
RUNFILES
=
"
${
BASH_SOURCE
[0]
}
.runfiles/__main__"
else
# Otherwise, use this recipe borrowed from
# https://github.com/bazelbuild/bazel/blob/7d265e07e7a1e37f04d53342710e4f21d9ee8083/examples/shell/test.sh#L21
# shellcheck disable=SC2091
RUNFILES
=
"
${
RUNFILES
:-
"
$(
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
)
"
;
pwd
)
"
}
"
fi
readonly RUNFILES
# Paths to the converter binary under test and its data dependencies.
readonly RUNTIME="${RUNFILES}/dragnn/runtime"
readonly CONVERTER="${RUNTIME}/converter"
readonly SAVED_MODEL="${RUNTIME}/testdata/rnn_tagger"
readonly MASTER_SPEC="${SAVED_MODEL}/assets.extra/master_spec"
# Directory holding the golden (expected) converter outputs.
readonly EXPECTED="${RUNTIME}/testdata/converter_output"
# Scratch directory for the converter's output; falls back to /tmp/<pid>
# when TEST_TMPDIR is unset.
readonly OUTPUT="${TEST_TMPDIR:-/tmp/$$}/converted"
# Fails the test: prints its arguments to stderr and exits non-zero.
function fail() {
  echo "$@" >&2
  exit 1
}
# Asserts that a regular file exists at "$1".
function assert_file_exists() {
  [[ -f "$1" ]] || fail "missing file: $1"
}
# Asserts that two files both exist and have identical content.
function assert_file_content_eq() {
  assert_file_exists "$1"
  assert_file_exists "$2"
  # diff -u prints the differences before the failure message, aiding debugging.
  diff -u "$1" "$2" || fail "files differ: $1 $2"
}
# Start from a clean scratch directory.
rm -rf "${OUTPUT}"

# Run the converter over the pre-trained test model.
"${CONVERTER}" \
  --saved_model_dir="${SAVED_MODEL}" \
  --master_spec_file="${MASTER_SPEC}" \
  --output_dir="${OUTPUT}" \
  --logtostderr

# With no arguments: compare each output against the golden file.
# With one argument (a testdata dir): regenerate the golden files instead.
for file in \
    'MasterSpec' \
    'ArrayVariableStoreData' \
    'ArrayVariableStoreSpec'; do
  if [[ $# -gt 0 ]]; then
    # Update expected output.
    rm -f "$1/${file}"
    cp -f "${OUTPUT}/${file}" "$1/${file}"
  else
    # Compare to expected output.
    assert_file_content_eq "${OUTPUT}/${file}" "${EXPECTED}/${file}"
  fi
done

# Clean up the scratch directory.
rm -rf "${OUTPUT}"
research/syntaxnet/dragnn/runtime/dynamic_component.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// The DynamicComponent is the runtime analogue of the DynamicComponentBuilder
// in the Python codebase. The role of the DynamicComponent is to manage the
// loop over transition steps, including:
// * Allocating stepwise memory for network states and operands.
// * Performing some computation at each step.
// * Advancing the transition state until terminal.
//
// Note that the number of transition taken on any given evaluation of the
// DynamicComponent cannot be determined in advance.
//
// The core computational work is delegated to a NetworkUnit, which is evaluated
// at each transition step. This makes the DynamicComponent flexible, since it
// can be applied to any NetworkUnit implementation, but it can be significantly
// more efficient to use a task-specific component implementation. For example,
// the "shift-only" transition system merely scans the input tokens, and in that
// case we could replace the incremental loop with a "bulk" computation.
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Performs an incremental computation, one transition at a time.
class DynamicComponent : public Component {
 protected:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;

  // This class is intended to support all DynamicComponent layers. We currently
  // prefer to return `true` here and throw errors in Initialize() if a
  // particular feature is not supported.
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "DynamicComponent";
  }

  // This class is not optimized, so any other supported subclasses of Component
  // should be preferred.
  bool PreferredTo(const Component &other) const override { return false; }

 private:
  // Name of this component.
  string name_;

  // Network unit that produces logits.
  std::unique_ptr<NetworkUnit> network_unit_;

  // Whether the transition system is deterministic.
  bool deterministic_ = false;

  // Handle to the network unit logits. Valid iff |deterministic_| is false.
  LayerHandle<float> logits_handle_;
};
tensorflow::Status DynamicComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  name_ = component_spec.name();
  if (!component_spec.attention_component().empty()) {
    return tensorflow::errors::Unimplemented("Attention is not supported");
  }

  // Create and initialize the network unit named by the |component_spec|.
  TF_RETURN_IF_ERROR(NetworkUnit::CreateOrError(
      NetworkUnit::GetClassName(component_spec), &network_unit_));
  TF_RETURN_IF_ERROR(network_unit_->Initialize(component_spec, variable_store,
                                               network_state_manager,
                                               extension_manager));

  // Logits are unnecessary when the component is deterministic.
  deterministic_ = TransitionSystemTraits(component_spec).is_deterministic;
  if (!deterministic_) {
    const string logits_name = network_unit_->GetLogitsName();
    if (logits_name.empty()) {
      return tensorflow::errors::InvalidArgument(
          "Network unit does not produce logits: ",
          component_spec.network_unit().ShortDebugString());
    }

    // Look up the logits layer and check that it has one column per action.
    size_t dimension = 0;
    TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
        name_, logits_name, &dimension, &logits_handle_));
    if (dimension != component_spec.num_actions()) {
      return tensorflow::errors::InvalidArgument(
          "Dimension mismatch between network unit logits (", dimension,
          ") and ComponentSpec.num_actions (", component_spec.num_actions(),
          ") in component '", name_, "'");
    }
  }
  return tensorflow::Status::OK();
}
// No batches or beams.
constexpr
int
kNumItems
=
1
;
tensorflow::Status DynamicComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  // Run the network unit once per transition, until the compute session
  // reports that this component's transition state is terminal.
  for (size_t step_index = 0; !compute_session->IsTerminal(name_);
       ++step_index) {
    network_states.AddStep();
    TF_RETURN_IF_ERROR(
        network_unit_->Evaluate(step_index, session_state, compute_session));

    // If the component is deterministic, take the oracle transition instead of
    // predicting the next transition using the logits.
    if (deterministic_) {
      compute_session->AdvanceFromOracle(name_);
    } else {
      // AddStep() may invalidate the logits (due to reallocation), so the layer
      // lookup cannot be hoisted out of this loop.
      const Vector<float> logits(
          network_states.GetLayer(logits_handle_).row(step_index));
      if (!compute_session->AdvanceFromPrediction(name_, logits.data(),
                                                  kNumItems, logits.size())) {
        return tensorflow::errors::Internal(
            "Error in ComputeSession::AdvanceFromPrediction()");
      }
    }
  }
  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_COMPONENT(DynamicComponent);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
_
;
using
::
testing
::
Return
;
// Dimension of the "steps" logits layer produced by the test networks below.
constexpr size_t kStepsDim = 41;

// Number of transition steps taken in each test run.
constexpr size_t kNumSteps = 23;
// Fills each row of its logits with the step index.
class StepsNetwork : public NetworkUnit {
 public:
  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return network_state_manager->AddLayer("steps", kStepsDim, &handle_);
  }
  string GetLogitsName() const override { return "steps"; }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override {
    // Write the step index into every logit of the current row.
    const MutableVector<float> row =
        session_state->network_states.GetLayer(handle_).row(step_index);
    for (float &value : row) value = static_cast<float>(step_index);
    return tensorflow::Status::OK();
  }

 private:
  // Handle to the logits layer.
  LayerHandle<float> handle_;
};

DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(StepsNetwork);
// As above, but does not report a logits layer.
class NoLogitsNetwork : public StepsNetwork {
 public:
  // Implements NetworkUnit.  Returns an empty name, which signals that this
  // network produces no logits; the "steps" layer is still written.
  string GetLogitsName() const override { return ""; }
};

DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(NoLogitsNetwork);
class DynamicComponentTest : public NetworkTestBase {
 protected:
  // Creates a component, initializes it based on the |component_spec_text| and
  // |network_unit_name|, and evaluates it. On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text,
                         const string &network_unit_name) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);
    component_spec.mutable_network_unit()->set_registered_name(
        network_unit_name);

    // Neither DynamicComponent nor the test networks use linked embeddings, so
    // a trivial network suffices.
    AddComponent(kTestComponentName);

    // Create and initialize the component under test.
    TF_RETURN_IF_ERROR(
        Component::CreateOrError("DynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    // Prepare the session state; initialization must happen first so the
    // managers are fully configured.
    network_states_.Reset(&network_state_manager_);
    StartComponent(0);  // DynamicComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));

    // Capture the "steps" layer for inspection by the tests.
    steps_ = GetLayer(kTestComponentName, "steps");
    return tensorflow::Status::OK();
  }

  // Component under test.
  std::unique_ptr<Component> component_;

  // Contents of the "steps" layer after a successful Run().
  Matrix<float> steps_;
};
// Tests that DynamicComponent fails if the spec uses attention.
TEST_F(DynamicComponentTest, UnsupportedAttention) {
  const tensorflow::Status status =
      Run("attention_component: 'foo'", "NoLogitsNetwork");
  EXPECT_THAT(status, test::IsErrorWithSubstr("Attention is not supported"));
}
// Tests that DynamicComponent fails if the network does not produce logits.
TEST_F(DynamicComponentTest, NoLogits) {
  const tensorflow::Status status = Run("", "NoLogitsNetwork");
  EXPECT_THAT(status,
              test::IsErrorWithSubstr("Network unit does not produce logits"));
}
// Tests that DynamicComponent fails if the logits do not have the required
// dimension.
TEST_F(DynamicComponentTest, MismatchedLogitsDimension) {
  const tensorflow::Status status = Run("num_actions: 42", "StepsNetwork");
  EXPECT_THAT(status,
              test::IsErrorWithSubstr(
                  "Dimension mismatch between network unit logits "
                  "(41) and ComponentSpec.num_actions (42)"));
}
// Tests that DynamicComponent fails if ComputeSession::AdvanceFromPrediction()
// returns false.
TEST_F(DynamicComponentTest, FailToAdvanceFromPrediction) {
  // Never terminal, so the component is forced to predict a transition...
  EXPECT_CALL(compute_session_, IsTerminal(_)).WillRepeatedly(Return(false));
  // ...and the very first prediction is rejected by the compute session.
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .WillOnce(Return(false));

  EXPECT_THAT(Run("num_actions: 41", "StepsNetwork"),
              test::IsErrorWithSubstr(
                  "Error in ComputeSession::AdvanceFromPrediction()"));
}
// Tests that DynamicComponent evaluates its network unit once per transition,
// each time passing the proper step index.
TEST_F(DynamicComponentTest, Steps) {
  SetupTransitionLoop(kNumSteps);

  // Accept |num_steps| transition steps.
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));

  TF_ASSERT_OK(Run("num_actions: 41", "StepsNetwork"));

  // Each row of the "steps" layer must be filled with its own row index.
  ASSERT_EQ(steps_.num_rows(), kNumSteps);
  for (size_t row = 0; row < kNumSteps; ++row) {
    ExpectVector(steps_.row(row), kStepsDim, row);
  }
}
// Tests that DynamicComponent calls ComputeSession::AdvanceFromOracle() and
// does not use logits when the component is deterministic.
// NOTE: Fixed a typo in the test name ("Determinstic" -> "Deterministic").
TEST_F(DynamicComponentTest, Deterministic) {
  SetupTransitionLoop(kNumSteps);

  // Take the oracle transition instead of predicting from logits.
  EXPECT_CALL(compute_session_, AdvanceFromOracle(_)).Times(kNumSteps);

  TF_EXPECT_OK(Run("num_actions: 1", "NoLogitsNetwork"));

  // The NoLogitsNetwork still produces the "steps" layer, even if it does not
  // mark them as its logits.
  ASSERT_EQ(steps_.num_rows(), kNumSteps);
  for (size_t step_index = 0; step_index < kNumSteps; ++step_index) {
    ExpectVector(steps_.row(step_index), kStepsDim, step_index);
  }
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/extensions.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/extensions.h"
#include <algorithm>
#include <iterator>
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Finds or creates the shared extension whose type is identified by |deleter|
// and writes its position into |index|.
void ExtensionManager::GetSharedImpl(Deleter deleter, size_t *index) {
  // Reuse an existing shared extension with a matching type ID, if any.
  for (size_t i = 0; i < configs_.size(); ++i) {
    const ExtensionConfig &config = configs_[i];
    if (config.is_shared && config.deleter == deleter) {
      *index = i;
      return;
    }
  }

  // No match; append a new shared extension at the next index.
  *index = configs_.size();
  configs_.emplace_back(/*is_shared=*/true, deleter);
}
// Appends a new local extension that uses |deleter| and writes its position
// into |index|.  Local extensions are never shared, so no lookup is needed.
void ExtensionManager::AddLocalImpl(Deleter deleter, size_t *index) {
  const size_t new_index = configs_.size();
  configs_.emplace_back(/*is_shared=*/false, deleter);
  *index = new_index;
}
// Transfers ownership of all extension objects from |that|, leaving |that| in
// the empty (default-constructed) state.
Extensions::Extensions(Extensions &&that)
    : manager_(std::exchange(that.manager_, nullptr)),
      extensions_(std::move(that.extensions_)) {
  // A moved-from vector is valid-but-unspecified; make it definitively empty
  // so |that|'s destructor and Clear() are no-ops.
  that.extensions_.clear();
}
// Move-assignment: destroys the current extensions, then takes ownership of
// the extensions in |that|, leaving |that| empty.
Extensions &Extensions::operator=(Extensions &&that) {
  // Guard against self-move: without this check, Clear() would delete every
  // extension before "moving" them back into place, leaving dangling slots.
  if (this != &that) {
    Clear();
    manager_ = that.manager_;
    extensions_ = std::move(that.extensions_);
    that.manager_ = nullptr;
    that.extensions_.clear();
  }
  return *this;
}
// Points this set of extensions at |manager|.  If the manager is unchanged,
// the existing extension objects are kept so they can be reused across
// invocations; otherwise all current extensions are destroyed and the slot
// vector is resized to match the new manager's configuration.
void Extensions::Reset(const ExtensionManager *manager) {
  if (manager == manager_) return;  // reuse existing extensions

  // Discard current extensions before reassigning the |manager_|.
  Clear();
  manager_ = manager;
  extensions_.assign(manager_->configs_.size(), nullptr);
}
void
Extensions
::
Clear
()
{
// NB: This works even if the |manager_| is null, because that only happens
// when |extensions_| is empty.
for
(
size_t
index
=
0
;
index
<
extensions_
.
size
();
++
index
)
{
manager_
->
configs_
[
index
].
deleter
(
extensions_
[
index
]);
}
extensions_
.
clear
();
manager_
=
nullptr
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/extensions.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring, allocating, and retrieving reusable typed extensions of
// the SessionState. There are two types of extensions:
//
// * Shared extensions, which are shared by all components in a DRAGNN network,
// like the layers in NetworkStates.
//
// * Local extensions, which are private to a particular component in a DRAGNN
// network, like the local operands in NetworkStates.
//
// Extensions are reused across network invocations, so users cannot rely on
// them having any particular state when they are retrieved. For example, a
// std::vector<int> extension could be filled with values from the previous
// invocation when it is retrieved.
//
// To maximize the benefits of reuse, use shared extensions when possible. In
// addition, avoid operations that can deallocate memory. For example, avoid
// resize()-ing a std::vector<std::vector<int>> extension to a smaller size,
// because that deallocates the trailing std::vector<int>s. On the other hand,
// a std::vector<int> extension can be resize()d freely, because that does not
// shrink capacity().
//
// NOTE: Theoretically, shared extensions can be used to pass information down
// the pipeline of components. However, this usage is not a supported and is
// unnecessary since components can already communicate via NetworkStates.
#ifndef DRAGNN_RUNTIME_EXTENSIONS_H_
#define DRAGNN_RUNTIME_EXTENSIONS_H_
#include <stddef.h>
#include <stdint.h>

#include <utility>
#include <vector>

#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Opaque handles used to access extensions.
template
<
class
T
>
class
SharedExtensionHandle
;
template
<
class
T
>
class
LocalExtensionHandle
;
// A class that manages a set of SessionState extensions.
class ExtensionManager {
 public:
  // Creates an empty manager.
  ExtensionManager() = default;

  // Sets |handle| to refer to the shared extension of type |T|, creating it if
  // it does not already exist. Calling N times with the same |T| results in N
  // handles to the same extension.
  template <class T>
  void GetShared(SharedExtensionHandle<T> *handle);

  // Sets |handle| to refer to a new local extension of type |T|. The extension
  // is "local" in the sense that only the caller knows its handle. Calling N
  // times with the same |T| results in N handles to N different extensions.
  template <class T>
  void AddLocal(LocalExtensionHandle<T> *handle);

 private:
  // Extensions reads |configs_| to size its slot vector and delete objects.
  friend class Extensions;

  // Function that can delete an untyped pointer using the proper type. All
  // |Deleter|s are pointers to instantiations of DeleteAs<T>() below, so this
  // can also be used as a type ID.
  using Deleter = void (*)(void *);

  // Configuration information for an extension.
  struct ExtensionConfig {
    ExtensionConfig(bool is_shared, Deleter deleter)
        : is_shared(is_shared), deleter(deleter) {}

    // Whether the extension is shared or local.
    const bool is_shared;

    // Extension deleter, which also serves as a type ID.
    const Deleter deleter;
  };

  // Deletes the |object| as a |T|. All |Deleter|s point to this function.
  template <class T>
  static void DeleteAs(void *object);

  // Implements the non-templated part of GetShared(). Sets |index| to the
  // index of the extension whose type matches the |deleter|, adding it if it
  // does not already exist.
  void GetSharedImpl(Deleter deleter, size_t *index);

  // Implements the non-templated part of AddLocal(). Adds an extension that
  // uses the |deleter| and sets |index| to its index.
  void AddLocalImpl(Deleter deleter, size_t *index);

  // Ordered list of configurations for all extensions.
  std::vector<ExtensionConfig> configs_;
};
// A set of SessionState extensions. The extensions are configured by an
// ExtensionManager, and instances of extension can be accessed using the
// handles produced by the manager.
//
// Note that this class is not thread-safe, so only one thread may access any
// particular instance at a time. In normal usage, this will be attached to a
// SessionState and thus single-threaded access is guaranteed.
class Extensions {
 public:
  // Creates an empty set of extensions.
  Extensions() = default;

  // Moves all extensions from |that| to this. Afterwards, the extensions in
  // this are address-equal to the extensions originally in |that|.
  // NB: Declaring these move operations implicitly deletes copying.
  Extensions(Extensions &&that);
  Extensions &operator=(Extensions &&that);

  // Destroys all extensions via their type-specific deleters.
  ~Extensions() { Clear(); }

  // Resets this to an empty set configured by the |manager|. The |manager|
  // must live until this is destroyed or Reset(), and should not be modified
  // during that time.
  void Reset(const ExtensionManager *manager);

  // Returns the shared extension associated with the |handle|. Creates the
  // extension first via "new T()" if it does not already exist.
  template <class T>
  T &Get(SharedExtensionHandle<T> handle);

  // Returns the local extension associated with the |handle|. Creates the
  // extension first via "new T(args)" if it does not already exist.
  template <class T, class... Args>
  T &Get(LocalExtensionHandle<T> handle, Args &&... args);

 private:
  // Restores this to a just-default-constructed state.
  void Clear();

  // Manager of this set of extensions.
  const ExtensionManager *manager_ = nullptr;

  // Ordered list of per-component operands, aligned with |manager_->configs_|.
  std::vector<void *> extensions_;
};
// Implementation details below.
// An opaque handle to a typed shared extension.
template <class T>
class SharedExtensionHandle {
 public:
  // Creates an invalid handle.
  SharedExtensionHandle() = default;

 private:
  // Only the manager may set the index, and only Extensions may read it.
  friend class ExtensionManager;
  friend class Extensions;

  // Index of this extension in the Extensions.  SIZE_MAX until initialized by
  // ExtensionManager::GetShared().
  size_t index_ = SIZE_MAX;
};
// An opaque handle to a typed local extension.
template <class T>
class LocalExtensionHandle {
 public:
  // Creates an invalid handle.
  LocalExtensionHandle() = default;

 private:
  // Only the manager may set the index, and only Extensions may read it.
  friend class ExtensionManager;
  friend class Extensions;

  // Index of this extension in the Extensions.  SIZE_MAX until initialized by
  // ExtensionManager::AddLocal().
  size_t index_ = SIZE_MAX;
};
// Deletes the |object| as a |T|.  Every |Deleter| points to an instantiation
// of this function, so its address doubles as a type ID.
template <class T>
void ExtensionManager::DeleteAs(void *object) {
  // static_cast is the correct named cast from void* back to the pointee type
  // (reinterpret_cast is reserved for unrelated-type punning).
  delete static_cast<T *>(object);
}
// Forwards to the non-templated GetSharedImpl(), using the address of the
// DeleteAs<T> instantiation as both deleter and type ID.
template <class T>
void ExtensionManager::GetShared(SharedExtensionHandle<T> *handle) {
  GetSharedImpl(&DeleteAs<T>, &handle->index_);
}
// Forwards to the non-templated AddLocalImpl(), using the address of the
// DeleteAs<T> instantiation as both deleter and type ID.
template <class T>
void ExtensionManager::AddLocal(LocalExtensionHandle<T> *handle) {
  AddLocalImpl(&DeleteAs<T>, &handle->index_);
}
// Returns the shared extension for |handle|, lazily default-constructing it
// on first access.
template <class T>
T &Extensions::Get(SharedExtensionHandle<T> handle) {
  DCHECK(manager_->configs_[handle.index_].is_shared);
  // The deleter doubles as a type ID, so this checks that T matches.
  DCHECK_EQ(manager_->configs_[handle.index_].deleter,
            &ExtensionManager::DeleteAs<T>);
  void *&extension = extensions_[handle.index_];
  if (extension == nullptr) extension = new T();
  // static_cast is the correct named cast from void*; the DCHECKs above
  // establish that the stored object really is a T.
  return *static_cast<T *>(extension);
}
// Returns the local extension for |handle|, lazily constructing it from
// |args| on first access.  NB: |args| are ignored once the extension exists.
template <class T, class... Args>
T &Extensions::Get(LocalExtensionHandle<T> handle, Args &&... args) {
  DCHECK(!manager_->configs_[handle.index_].is_shared);
  // The deleter doubles as a type ID, so this checks that T matches.
  DCHECK_EQ(manager_->configs_[handle.index_].deleter,
            &ExtensionManager::DeleteAs<T>);
  void *&extension = extensions_[handle.index_];
  if (extension == nullptr) extension = new T(std::forward<Args>(args)...);
  // static_cast is the correct named cast from void*; the DCHECKs above
  // establish that the stored object really is a T.
  return *static_cast<T *>(extension);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_EXTENSIONS_H_
research/syntaxnet/dragnn/runtime/extensions_test.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/extensions.h"
#include <utility>
#include <vector>
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
ElementsAre
;
// Dummy struct for tests.
struct Foo {
  Foo() = default;

  // Initializes |real| and fills |ints| with 0, 1, ..., num-1.
  explicit Foo(float real, int num) : real(real) {
    if (num > 0) ints.reserve(num);
    for (int value = 0; value < num; ++value) ints.push_back(value);
  }

  float real = 0.0;
  std::vector<int> ints;
};
// Returns a shared extension handle from the |manager|.  Convenience wrapper
// that converts the output-parameter API into a by-value return.
template <class T>
SharedExtensionHandle<T> GetShared(ExtensionManager *manager) {
  SharedExtensionHandle<T> handle;
  manager->GetShared(&handle);
  return handle;
}
// Returns a local extension handle from the |manager|.  Convenience wrapper
// that converts the output-parameter API into a by-value return.
template <class T>
LocalExtensionHandle<T> AddLocal(ExtensionManager *manager) {
  LocalExtensionHandle<T> handle;
  manager->AddLocal(&handle);
  return handle;
}
// Tests that GetShared() reuses existing extensions.
TEST(ExtensionManagerTest, GetShared) {
  ExtensionManager manager;
  // Requesting the same type twice should yield handles to one extension.
  const auto foo_handle1 = GetShared<Foo>(&manager);
  const auto int_handle = GetShared<int>(&manager);
  const auto foo_handle2 = GetShared<Foo>(&manager);

  Extensions extensions;
  extensions.Reset(&manager);
  Foo &foo1 = extensions.Get(foo_handle1);
  Foo &foo2 = extensions.Get(foo_handle2);
  // Both handles resolve to the same default-constructed object.
  EXPECT_EQ(&foo1, &foo2);
  EXPECT_EQ(foo1.real, 0.0);
  EXPECT_TRUE(foo1.ints.empty());
  EXPECT_EQ(extensions.Get(int_handle), 0);  // T() zero-initializes POD
}
// Tests that AddLocal() always adds a new extension.
TEST(ExtensionManagerTest, AddLocal) {
  ExtensionManager manager;
  // Requesting the same type twice should yield two distinct extensions.
  const auto foo_handle1 = AddLocal<Foo>(&manager);
  const auto int_handle = AddLocal<int>(&manager);
  const auto foo_handle2 = AddLocal<Foo>(&manager);

  Extensions extensions;
  extensions.Reset(&manager);
  Foo &foo1 = extensions.Get(foo_handle1);
  Foo &foo2 = extensions.Get(foo_handle2);
  // Distinct objects, each default-constructed.
  EXPECT_NE(&foo1, &foo2);
  EXPECT_EQ(foo1.real, 0.0);
  EXPECT_EQ(foo2.real, 0.0);
  EXPECT_TRUE(foo1.ints.empty());
  EXPECT_TRUE(foo2.ints.empty());
  EXPECT_EQ(extensions.Get(int_handle), 0);  // T() zero-initializes POD
}
// Tests that Get() always returns the same object.
TEST(ExtensionManagerTest, GetReturnsSameObject) {
  ExtensionManager manager;
  const auto foo_shared = GetShared<Foo>(&manager);
  const auto int_shared = GetShared<int>(&manager);
  const auto foo_local = AddLocal<Foo>(&manager);
  const auto int_local = AddLocal<int>(&manager);

  Extensions extensions;
  extensions.Reset(&manager);
  // The first Get() creates each extension...
  Foo &foo_shared1 = extensions.Get(foo_shared);
  int &int_shared1 = extensions.Get(int_shared);
  Foo &foo_local1 = extensions.Get(foo_local);
  int &int_local1 = extensions.Get(int_local);
  // ...and subsequent Get()s must return the same instances.
  Foo &foo_shared2 = extensions.Get(foo_shared);
  int &int_shared2 = extensions.Get(int_shared);
  Foo &foo_local2 = extensions.Get(foo_local);
  int &int_local2 = extensions.Get(int_local);
  EXPECT_EQ(&foo_shared1, &foo_shared2);
  EXPECT_EQ(&int_shared1, &int_shared2);
  EXPECT_EQ(&foo_local1, &foo_local2);
  EXPECT_EQ(&int_local1, &int_local2);
}
// Tests that local extensions can use non-default constructors.
TEST(ExtensionManagerTest, LocalAllowsNonDefaultConstructor) {
  ExtensionManager manager;
  const auto foo_handle = AddLocal<Foo>(&manager);
  const auto int_handle = AddLocal<int>(&manager);

  Extensions extensions;
  extensions.Reset(&manager);

  // Use non-default constructors to get initialized values.
  Foo &foo1 = extensions.Get(foo_handle, 0.5, 5);
  EXPECT_EQ(foo1.real, 0.5);
  EXPECT_THAT(foo1.ints, ElementsAre(0, 1, 2, 3, 4));
  EXPECT_EQ(extensions.Get(int_handle, -123), -123);

  // However, once created, the non-default constructor args are ignored.
  Foo &foo2 = extensions.Get(foo_handle, 1.23, 1000);
  EXPECT_EQ(foo2.real, 0.5);
  EXPECT_THAT(foo2.ints, ElementsAre(0, 1, 2, 3, 4));
  EXPECT_EQ(extensions.Get(int_handle, -456), -123);
}
// Tests that calling Reset() with the same manager is a NOP.
TEST(ExtensionManagerTest, ResetWithSameManager) {
  ExtensionManager manager;
  const auto foo_shared = GetShared<Foo>(&manager);
  const auto int_shared = GetShared<int>(&manager);
  const auto foo_local = AddLocal<Foo>(&manager);
  const auto int_local = AddLocal<int>(&manager);

  Extensions extensions;
  extensions.Reset(&manager);
  Foo &foo_shared1 = extensions.Get(foo_shared);
  int &int_shared1 = extensions.Get(int_shared);
  Foo &foo_local1 = extensions.Get(foo_local);
  int &int_local1 = extensions.Get(int_local);

  // Reset() with the same manager must preserve the existing extensions.
  extensions.Reset(&manager);
  Foo &foo_shared2 = extensions.Get(foo_shared);
  int &int_shared2 = extensions.Get(int_shared);
  Foo &foo_local2 = extensions.Get(foo_local);
  int &int_local2 = extensions.Get(int_local);
  EXPECT_EQ(&foo_shared1, &foo_shared2);
  EXPECT_EQ(&int_shared1, &int_shared2);
  EXPECT_EQ(&foo_local1, &foo_local2);
  EXPECT_EQ(&int_local1, &int_local2);
}
// Tests that Reset() can be used to switch managers.
TEST(ExtensionManagerTest, ResetWithDifferentManager) {
  ExtensionManager manager1;
  const auto foo_shared = GetShared<Foo>(&manager1);
  const auto foo_local = AddLocal<Foo>(&manager1);

  ExtensionManager manager2;
  const auto int_shared = GetShared<int>(&manager2);
  const auto int_local = AddLocal<int>(&manager2);

  Extensions extensions;
  // First attach to |manager1| and populate its extensions.
  extensions.Reset(&manager1);
  EXPECT_EQ(extensions.Get(foo_shared).real, 0.0);
  EXPECT_EQ(extensions.Get(foo_local, 0.75, 3).real, 0.75);

  // Switching to |manager2| discards the old extensions; the new handles
  // resolve to freshly-created objects.
  extensions.Reset(&manager2);
  EXPECT_EQ(extensions.Get(int_shared), 0);
  EXPECT_EQ(extensions.Get(int_local, 5), 5);
}
// Tests that Extensions supports move construction.
TEST(ExtensionManagerTest, MoveConstruction) {
  ExtensionManager manager;
  const auto foo_shared = GetShared<Foo>(&manager);
  const auto int_shared = GetShared<int>(&manager);
  const auto foo_local = AddLocal<Foo>(&manager);
  const auto int_local = AddLocal<int>(&manager);

  // Add a couple more spurious extensions that are never set, to exercise
  // movement of non-present extensions.
  GetShared<float>(&manager);
  AddLocal<float>(&manager);

  Extensions extensions1;
  extensions1.Reset(&manager);
  Foo &foo_shared1 = extensions1.Get(foo_shared);
  int &int_shared1 = extensions1.Get(int_shared);
  Foo &foo_local1 = extensions1.Get(foo_local);
  int &int_local1 = extensions1.Get(int_local);

  // Move-construction must preserve the addresses of the extensions.
  Extensions extensions2 = std::move(extensions1);
  Foo &foo_shared2 = extensions2.Get(foo_shared);
  int &int_shared2 = extensions2.Get(int_shared);
  Foo &foo_local2 = extensions2.Get(foo_local);
  int &int_local2 = extensions2.Get(int_local);
  EXPECT_EQ(&foo_shared1, &foo_shared2);
  EXPECT_EQ(&int_shared1, &int_shared2);
  EXPECT_EQ(&foo_local1, &foo_local2);
  EXPECT_EQ(&int_local1, &int_local2);
}
// Tests that Extensions supports move assignment.
TEST(ExtensionManagerTest, MoveAssignment) {
  ExtensionManager manager1;
  const auto foo_shared = GetShared<Foo>(&manager1);
  const auto foo_local = AddLocal<Foo>(&manager1);

  ExtensionManager manager2;
  const auto int_shared = GetShared<int>(&manager2);
  const auto int_local = AddLocal<int>(&manager2);

  // Add a couple more spurious extensions that are never set, to exercise
  // movement of non-present extensions.
  GetShared<float>(&manager1);
  GetShared<float>(&manager2);
  AddLocal<float>(&manager1);
  AddLocal<float>(&manager2);

  // Fill two sets of extensions.
  Extensions extensions1;
  extensions1.Reset(&manager1);
  extensions1.Get(foo_shared).real = 1.0;
  extensions1.Get(foo_local).real = 1.0;

  Extensions extensions2;
  extensions2.Reset(&manager2);
  extensions2.Get(int_shared) = 2;
  extensions2.Get(int_local) = 2;

  // Use a third set of extensions to perform a swap.
  Extensions extensions3;
  extensions3 = std::move(extensions1);
  extensions1 = std::move(extensions2);
  extensions2 = std::move(extensions3);

  // After the swap, each set holds the other's values.
  EXPECT_EQ(extensions1.Get(int_shared), 2);
  EXPECT_EQ(extensions1.Get(int_local), 2);
  EXPECT_EQ(extensions2.Get(foo_shared).real, 1.0);
  EXPECT_EQ(extensions2.Get(foo_local).real, 1.0);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/feed_forward_network.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/feed_forward_network_kernel.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/network_unit_base.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// A network unit that evaluates a feed-forward multi-layer perceptron.
class FeedForwardNetwork : public NetworkUnitBase {
 public:
  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  // Delegates to the kernel, which knows whether a logits layer exists.
  string GetLogitsName() const override { return kernel_.logits_name(); }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override;

 private:
  // Kernel that implements the feed-forward network.
  FeedForwardNetworkKernel kernel_;
};
// Initializes the kernel and the shared network-unit base, then validates
// that the configured layers agree with the concatenated input dimension.
tensorflow::Status FeedForwardNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  TF_RETURN_IF_ERROR(kernel_.Initialize(component_spec, variable_store,
                                        network_state_manager));
  // The feed-forward network always consumes the concatenated input.
  const bool use_concatenated_input = true;
  TF_RETURN_IF_ERROR(InitializeBase(use_concatenated_input, component_spec,
                                    variable_store, network_state_manager,
                                    extension_manager));

  // Check dimensions across layers. This must be done after InitializeBase(),
  // when concatenated_input_dim() is known.
  return kernel_.ValidateInputDimension(concatenated_input_dim());
}
// Evaluates the multi-layer perceptron for one step: gathers the concatenated
// input, then applies each layer in sequence.
tensorflow::Status FeedForwardNetwork::Evaluate(
    size_t step_index, SessionState *session_state,
    ComputeSession *compute_session) const {
  Vector<float> input;
  TF_RETURN_IF_ERROR(EvaluateBase(session_state, compute_session, &input));

  // Chain the layers: the output of each layer becomes the input of the next.
  for (const FeedForwardNetworkLayer &layer : kernel_.layers()) {
    input = layer.Apply(input, session_state->network_states, step_index);
  }
  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(FeedForwardNetwork);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/feed_forward_network_kernel.cc
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/feed_forward_network_kernel.h"
#include "dragnn/runtime/activation_functions.h"
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Attributes used by the feed-forward network.  Parsed from the
// ComponentSpec's network-unit parameters via the Attributes base class.
struct FeedForwardNetworkAttributes : public Attributes {
  // Hidden layer sizes; e.g., "64,64,32".
  Optional<std::vector<size_t>> hidden_layer_sizes{"hidden_layer_sizes",
                                                   {},
                                                   this};

  // Whether to omit the "logits" layer.
  Optional<bool> omit_logits{"omit_logits", false, this};

  // Only the default settings are supported for these attributes.
  Optional<bool> layer_norm_input{"layer_norm_input", false, this};
  Optional<bool> layer_norm_hidden{"layer_norm_hidden", false, this};
  Optional<string> nonlinearity{"nonlinearity", "relu", this};

  // Training-only attributes, ignored in the runtime.
  Ignored dropout_keep_prob{"dropout_keep_prob", this};
  Ignored dropout_per_sequence{"dropout_per_sequence", this};
  Ignored dropout_all_layers{"dropout_all_layers", this};
  Ignored initialize_bias_zero{"initialize_bias_zero", this};
  Ignored initialize_softmax_zero{"initialize_softmax_zero", this};
  Ignored initialize_hidden_orthogonal{"initialize_hidden_orthogonal", this};
};
}
// namespace
// Configures this kernel from the |component_spec|: parses the network-unit
// attributes, rejects unsupported settings, and registers one layer per
// configured hidden size plus an optional softmax "logits" layer.  Layers are
// added in order, so layer indices match the order of |hidden_layer_sizes|.
// Returns non-OK on an unsupported or malformed spec.
tensorflow::Status FeedForwardNetworkKernel::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager) {
  FeedForwardNetworkAttributes attrs;
  TF_RETURN_IF_ERROR(attrs.Reset(component_spec.network_unit().parameters()));

  // Reject configurations that this runtime kernel cannot reproduce.
  const bool wants_layer_norm =
      attrs.layer_norm_input() || attrs.layer_norm_hidden();
  if (wants_layer_norm) {
    return tensorflow::errors::Unimplemented("Layer norm is not supported");
  }
  if (attrs.nonlinearity() != "relu") {
    return tensorflow::errors::Unimplemented("Non-linearity is not supported: ",
                                             attrs.nonlinearity());
  }

  // One FeedForwardNetworkLayer per configured hidden size.  The layer is
  // named "layer_<index>" and its variables are looked up under the bare
  // index suffix.
  for (const size_t dimension : attrs.hidden_layer_sizes()) {
    const size_t index = layers_.size();
    const string layer_name = tensorflow::strings::StrCat("layer_", index);
    const string variable_suffix = tensorflow::strings::StrCat(index);
    layers_.emplace_back();
    TF_RETURN_IF_ERROR(layers_.back().Initialize(
        component_spec.name(), layer_name, dimension, ActivationFunction::kRelu,
        variable_suffix, variable_store, network_state_manager));
  }

  // Alias "last_layer" to the final hidden layer, when one exists.
  if (!layers_.empty()) {
    const string last_name =
        tensorflow::strings::StrCat("layer_", layers_.size() - 1);
    TF_RETURN_IF_ERROR(
        network_state_manager->AddLayerAlias("last_layer", last_name));
  }

  // A softmax "logits" layer is appended unless the transition system is
  // deterministic or the spec explicitly omits it.
  const bool deterministic =
      TransitionSystemTraits(component_spec).is_deterministic;
  if (!deterministic && !attrs.omit_logits()) {
    logits_name_ = FeedForwardNetworkLayer::kLogitsName;
    layers_.emplace_back();
    TF_RETURN_IF_ERROR(layers_.back().InitializeSoftmax(
        component_spec, variable_store, network_state_manager));
  }

  return tensorflow::Status::OK();
}
// Returns OK iff an input of |dimension| columns is compatible with the
// layer stack.  The dimension is threaded through the layers in order: each
// layer checks it against its own input width and replaces it with its
// output width for the next layer's check.
tensorflow::Status FeedForwardNetworkKernel::ValidateInputDimension(
    size_t dimension) const {
  size_t current = dimension;
  for (size_t i = 0; i < layers_.size(); ++i) {
    TF_RETURN_IF_ERROR(
        layers_[i].CheckInputDimAndGetOutputDim(current, &current));
  }
  return tensorflow::Status::OK();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/feed_forward_network_kernel.h
0 → 100644
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_FEED_FORWARD_NETWORK_KERNEL_H_
#define DRAGNN_RUNTIME_FEED_FORWARD_NETWORK_KERNEL_H_
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A kernel that evaluates a multi-layer perceptron.
// A kernel that evaluates a multi-layer perceptron.
//
// Usage: call Initialize() once against a ComponentSpec, optionally
// ValidateInputDimension() to check compatibility with a feature vector
// width, then read layers() to drive evaluation.
class FeedForwardNetworkKernel {
 public:
  // Initializes this to the configuration in the |component_spec|. Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component. On error, returns non-OK.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager);

  // Returns OK iff this is compatible with the input |dimension|.
  tensorflow::Status ValidateInputDimension(size_t dimension) const;

  // Accessors.
  const std::vector<FeedForwardNetworkLayer> &layers() const {
    return layers_;
  }
  // Returns the logits layer name; empty if no logits layer was configured.
  const string &logits_name() const { return logits_name_; }

 private:
  // List of layers, including hidden layers and the logits, if any.
  std::vector<FeedForwardNetworkLayer> layers_;

  // Name of the logits layer.
  string logits_name_;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_FEED_FORWARD_NETWORK_KERNEL_H_
Prev
1
…
3
4
5
6
7
8
9
10
11
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment