Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
2295 deletions
+0
-2295
research/syntaxnet/dragnn/runtime/bulk_lstm_network.cc
research/syntaxnet/dragnn/runtime/bulk_lstm_network.cc
+0
-65
research/syntaxnet/dragnn/runtime/bulk_lstm_network_test.cc
research/syntaxnet/dragnn/runtime/bulk_lstm_network_test.cc
+0
-166
research/syntaxnet/dragnn/runtime/bulk_network_unit.cc
research/syntaxnet/dragnn/runtime/bulk_network_unit.cc
+0
-44
research/syntaxnet/dragnn/runtime/bulk_network_unit.h
research/syntaxnet/dragnn/runtime/bulk_network_unit.h
+0
-101
research/syntaxnet/dragnn/runtime/bulk_network_unit_test.cc
research/syntaxnet/dragnn/runtime/bulk_network_unit_test.cc
+0
-89
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer.cc
...net/dragnn/runtime/clear_dropout_component_transformer.cc
+0
-48
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer_test.cc
...ragnn/runtime/clear_dropout_component_transformer_test.cc
+0
-62
research/syntaxnet/dragnn/runtime/component.cc
research/syntaxnet/dragnn/runtime/component.cc
+0
-107
research/syntaxnet/dragnn/runtime/component.h
research/syntaxnet/dragnn/runtime/component.h
+0
-111
research/syntaxnet/dragnn/runtime/component_test.cc
research/syntaxnet/dragnn/runtime/component_test.cc
+0
-201
research/syntaxnet/dragnn/runtime/component_transformation.cc
...arch/syntaxnet/dragnn/runtime/component_transformation.cc
+0
-91
research/syntaxnet/dragnn/runtime/component_transformation.h
research/syntaxnet/dragnn/runtime/component_transformation.h
+0
-86
research/syntaxnet/dragnn/runtime/component_transformation_test.cc
...syntaxnet/dragnn/runtime/component_transformation_test.cc
+0
-241
research/syntaxnet/dragnn/runtime/conversion.cc
research/syntaxnet/dragnn/runtime/conversion.cc
+0
-82
research/syntaxnet/dragnn/runtime/conversion.h
research/syntaxnet/dragnn/runtime/conversion.h
+0
-58
research/syntaxnet/dragnn/runtime/conversion_test.cc
research/syntaxnet/dragnn/runtime/conversion_test.cc
+0
-140
research/syntaxnet/dragnn/runtime/converter.cc
research/syntaxnet/dragnn/runtime/converter.cc
+0
-145
research/syntaxnet/dragnn/runtime/converter_test.sh
research/syntaxnet/dragnn/runtime/converter_test.sh
+0
-92
research/syntaxnet/dragnn/runtime/dynamic_component.cc
research/syntaxnet/dragnn/runtime/dynamic_component.cc
+0
-173
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
+0
-193
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/bulk_lstm_network.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/lstm_network_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// A network unit that evaluates an LSTM.
class
BulkLSTMNetwork
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
kernel_
.
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
);
}
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
kernel_
.
GetLogitsName
();
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
{
return
kernel_
.
Apply
(
inputs
,
session_state
);
}
private:
// Kernel that implements the LSTM.
LSTMNetworkKernel
kernel_
{
/*bulk=*/
true
};
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkLSTMNetwork
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_lstm_network_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
constexpr
size_t
kNumSteps
=
20
;
constexpr
size_t
kNumActions
=
10
;
constexpr
size_t
kInputDim
=
32
;
constexpr
size_t
kHiddenDim
=
8
;
class
BulkLSTMNetworkTest
:
public
NetworkTestBase
{
protected:
// Adds a blocked weight matrix with the |name| with the given dimensions and
// |fill_value|. If |is_flexible_matrix| is true, the variable is set up for
// use by the FlexibleMatrixKernel.
void
AddWeights
(
const
string
&
name
,
size_t
input_dim
,
size_t
output_dim
,
float
fill_value
,
bool
is_flexible_matrix
=
false
)
{
constexpr
int
kBatchSize
=
LstmCellFunction
<>::
kBatchSize
;
size_t
output_padded
=
kBatchSize
*
((
output_dim
+
kBatchSize
-
1
)
/
kBatchSize
);
size_t
num_views
=
(
output_padded
/
kBatchSize
)
*
input_dim
;
string
var_name
=
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/"
,
name
,
is_flexible_matrix
?
FlexibleMatrixKernel
::
kSuffix
:
"/matrix/blocked48"
);
const
std
::
vector
<
float
>
block
(
kBatchSize
,
fill_value
);
const
std
::
vector
<
std
::
vector
<
float
>>
blocks
(
num_views
,
block
);
variable_store_
.
AddOrDie
(
var_name
,
blocks
,
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
);
variable_store_
.
SetBlockedDimensionOverride
(
var_name
,
{
input_dim
,
output_padded
,
kBatchSize
});
}
// Adds a bias vector with the |name_suffix| with the given dimensions and
// |fill_value|.
void
AddBiases
(
const
string
&
name
,
size_t
dimension
,
float
fill_value
)
{
const
string
biases_name
=
tensorflow
::
strings
::
StrCat
(
kTestComponentName
,
"/"
,
name
);
AddVectorVariable
(
biases_name
,
dimension
,
fill_value
);
}
// Initializes the |bulk_network_unit_| from the |component_spec_text|. On
// error, returns non-OK.
tensorflow
::
Status
Initialize
(
const
string
&
component_spec_text
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
BulkNetworkUnit
::
CreateOrError
(
"BulkLSTMNetwork"
,
&
bulk_network_unit_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
TF_RETURN_IF_ERROR
(
bulk_network_unit_
->
ValidateInputDimension
(
kInputDim
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
kNumSteps
);
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
return
tensorflow
::
Status
::
OK
();
}
// Evaluates the |bulk_network_unit_| on the |inputs|.
void
Apply
(
const
std
::
vector
<
std
::
vector
<
float
>>
&
inputs
)
{
UniqueMatrix
<
float
>
input_matrix
(
inputs
);
TF_ASSERT_OK
(
bulk_network_unit_
->
Evaluate
(
Matrix
<
float
>
(
*
input_matrix
),
&
session_state_
));
}
// Returns the logits matrix.
Matrix
<
float
>
GetLogits
()
const
{
return
Matrix
<
float
>
(
GetLayer
(
kTestComponentName
,
"logits"
));
}
std
::
unique_ptr
<
BulkNetworkUnit
>
bulk_network_unit_
;
};
TEST_F
(
BulkLSTMNetworkTest
,
NormalOperation
)
{
const
string
kSpec
=
R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 32
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '8'
}
}
num_actions: 10)"
;
constexpr
float
kEmbedding
=
1.25
;
constexpr
float
kWeight
=
1.5
;
// Same as above, with "softmax" weights and biases.
AddWeights
(
"x_to_ico"
,
kInputDim
,
3
*
kHiddenDim
,
kWeight
);
AddWeights
(
"h_to_ico"
,
kHiddenDim
,
3
*
kHiddenDim
,
kWeight
);
AddWeights
(
"c2i"
,
kHiddenDim
,
kHiddenDim
,
kWeight
);
AddWeights
(
"c2o"
,
kHiddenDim
,
kHiddenDim
,
kWeight
);
AddWeights
(
"weights_softmax"
,
kHiddenDim
,
kNumActions
,
kWeight
,
/*is_flexible_matrix=*/
true
);
AddBiases
(
"ico_bias"
,
3
*
kHiddenDim
,
kWeight
);
AddBiases
(
"bias_softmax"
,
kNumActions
,
kWeight
);
TF_EXPECT_OK
(
Initialize
(
kSpec
));
// Logits should exist.
EXPECT_EQ
(
bulk_network_unit_
->
GetLogitsName
(),
"logits"
);
const
std
::
vector
<
float
>
row
(
kInputDim
,
kEmbedding
);
const
std
::
vector
<
std
::
vector
<
float
>>
rows
(
kNumSteps
,
row
);
Apply
(
rows
);
// Logits dimension matches "num_actions" above. We don't test the values very
// precisely here, and feel free to update if the cell function changes. Most
// value tests should be in lstm_cell/cell_function_test.cc.
Matrix
<
float
>
logits
=
GetLogits
();
EXPECT_EQ
(
logits
.
num_rows
(),
kNumSteps
);
EXPECT_EQ
(
logits
.
num_columns
(),
kNumActions
);
EXPECT_NEAR
(
logits
.
row
(
0
)[
0
],
10.6391
,
0.1
);
for
(
int
row
=
0
;
row
<
logits
.
num_rows
();
++
row
)
{
for
(
const
float
value
:
logits
.
row
(
row
))
{
EXPECT_EQ
(
value
,
logits
.
row
(
0
)[
0
])
<<
"With uniform weights, all logits should be equal."
;
}
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_network_unit.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <vector>
#include "dragnn/runtime/network_unit.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
string
BulkNetworkUnit
::
GetClassName
(
const
ComponentSpec
&
component_spec
)
{
// The network unit name specified in the |component_spec| is for the Python
// registry and cannot be passed directly to the C++ registry. The function
// below extracts the C++ registered name; e.g.,
// "some.module.FooNetwork" => "FooNetwork".
// We then prepend "Bulk" to distinguish it from the non-bulk version.
return
tensorflow
::
strings
::
StrCat
(
"Bulk"
,
NetworkUnit
::
GetClassName
(
component_spec
));
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Bulk Network Unit"
,
dragnn
::
runtime
::
BulkNetworkUnit
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/bulk_network_unit.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#include <stddef.h>
#include <functional>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for network units for bulk inference.
//
// TODO(googleuser): The current approach assumes that fixed and
// linked embeddings are computed and concatenated outside the network unit,
// which is simple and composable. However, it could be more efficient to,
// e.g., pass the fixed and linked embeddings individually or compute them
// internally. That would elide the concatenation and could increase cache
// coherency.
class
BulkNetworkUnit
:
public
RegisterableClass
<
BulkNetworkUnit
>
{
public:
BulkNetworkUnit
(
const
BulkNetworkUnit
&
that
)
=
delete
;
BulkNetworkUnit
&
operator
=
(
const
BulkNetworkUnit
&
that
)
=
delete
;
virtual
~
BulkNetworkUnit
()
=
default
;
// Returns the bulk network unit class name specified in the |component_spec|.
static
string
GetClassName
(
const
ComponentSpec
&
component_spec
);
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
=
0
;
// Returns OK iff this is compatible with the input |dimension|.
virtual
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
=
0
;
// Returns the name of the layer that contains classification logits, or an
// empty string if this does not produce logits. Requires that Initialize()
// was called.
virtual
string
GetLogitsName
()
const
=
0
;
// Evaluates this network on the bulk |inputs|, using intermediate operands
// and output layers in the |session_state|. On error, returns non-OK.
virtual
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
=
0
;
protected:
BulkNetworkUnit
()
=
default
;
private:
// Helps prevent use of the Create() method; use CreateOrError() instead.
using
RegisterableClass
<
BulkNetworkUnit
>::
Create
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Bulk Network Unit"
,
dragnn
::
runtime
::
BulkNetworkUnit
);
}
// namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::BulkNetworkUnit, #subclass, subclass)
#endif // DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
research/syntaxnet/dragnn/runtime/bulk_network_unit_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Expects that the two pointers have the same address.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
// A trivial implementation for tests.
class
BulkFooNetwork
:
public
BulkNetworkUnit
{
public:
// Implements BulkNetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
ValidateInputDimension
(
size_t
dimension
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
string
GetLogitsName
()
const
override
{
return
"foo_logits"
;
}
tensorflow
::
Status
Evaluate
(
Matrix
<
float
>
inputs
,
SessionState
*
session_state
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT
(
BulkFooNetwork
);
// Tests that BulkNetworkUnit::GetClassName() resolves names properly.
TEST
(
BulkNetworkUnitTest
,
GetClassName
)
{
for
(
const
string
&
registered_name
:
{
"FooNetwork"
,
"module.FooNetwork"
,
"some.long.path.to.module.FooNetwork"
})
{
ComponentSpec
component_spec
;
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
registered_name
);
EXPECT_EQ
(
BulkNetworkUnit
::
GetClassName
(
component_spec
),
"BulkFooNetwork"
);
}
}
// Tests that BulkNetworkUnits can be created via the registry.
TEST
(
BulkNetworkUnitTest
,
CreateOrError
)
{
std
::
unique_ptr
<
BulkNetworkUnit
>
foo
;
TF_ASSERT_OK
(
BulkNetworkUnit
::
CreateOrError
(
"BulkFooNetwork"
,
&
foo
));
ASSERT_TRUE
(
foo
!=
nullptr
);
ExpectSameAddress
(
dynamic_cast
<
BulkFooNetwork
*>
(
foo
.
get
()),
foo
.
get
());
EXPECT_EQ
(
foo
->
GetLogitsName
(),
"foo_logits"
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Transformer that removes dropout settings.
class
ClearDropoutComponentTransformer
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
override
{
for
(
FixedFeatureChannel
&
channel
:
*
component_spec
->
mutable_fixed_feature
())
{
channel
.
clear_dropout_id
();
channel
.
clear_dropout_keep_probability
();
}
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
ClearDropoutComponentTransformer
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/clear_dropout_component_transformer_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Tests that a spec with no dropout features is unmodified.
TEST
(
ClearDropoutComponentTransformerTest
,
DoesNotModifyIfNoDropout
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"foo"
);
component_spec
.
add_fixed_feature
()
->
set_name
(
"words"
);
const
ComponentSpec
expected_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that a spec with dropout features is modified.
TEST
(
ClearDropoutComponentTransformerTest
,
ClearsDropout
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"foo"
);
FixedFeatureChannel
*
channel
=
component_spec
.
add_fixed_feature
();
channel
->
set_name
(
"words"
);
channel
->
set_dropout_id
(
100
);
channel
->
add_dropout_keep_probability
(
1.0
);
channel
->
add_dropout_keep_probability
(
0.5
);
channel
->
add_dropout_keep_probability
(
0.1
);
ComponentSpec
expected_spec
=
component_spec
;
expected_spec
.
clear_fixed_feature
();
expected_spec
.
add_fixed_feature
()
->
set_name
(
"words"
);
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
expected_spec
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
string
GetNormalizedComponentBuilderName
(
const
ComponentSpec
&
component_spec
)
{
// The Python registration API is based on (relative) module paths, such as
// "some.module.FooComponent". Discard the module path prefix and use only
// the final segment, which is the subclass name.
const
std
::
vector
<
string
>
segments
=
tensorflow
::
str_util
::
Split
(
component_spec
.
component_builder
().
registered_name
(),
"."
);
CHECK_GT
(
segments
.
size
(),
0
)
<<
"No builder name for component spec: "
<<
component_spec
.
ShortDebugString
();
tensorflow
::
StringPiece
subclass_name
=
segments
.
back
();
// In addition, remove a "Builder" suffix, if any. In the Python codebase, a
// ComponentBuilder builds a TF graph to perform some computation, whereas in
// the runtime, a Component directly executes that computation.
tensorflow
::
str_util
::
ConsumeSuffix
(
&
subclass_name
,
"Builder"
);
return
subclass_name
.
ToString
();
}
tensorflow
::
Status
Component
::
Select
(
const
ComponentSpec
&
spec
,
string
*
result
)
{
const
string
normalized_builder_name
=
GetNormalizedComponentBuilderName
(
spec
);
// Iterate through all registered components, constructing them and querying
// their Supports() methods.
std
::
unique_ptr
<
Component
>
current_best
;
string
current_best_name
;
for
(
const
Registry
::
Registrar
*
component
=
registry
()
->
components
;
component
!=
nullptr
;
component
=
component
->
next
())
{
// component->object() is a function pointer to the subclass' constructor.
std
::
unique_ptr
<
Component
>
next
(
component
->
object
()());
string
next_name
(
component
->
name
());
if
(
!
next
->
Supports
(
spec
,
normalized_builder_name
))
{
continue
;
}
// First supported component.
if
(
current_best
==
nullptr
)
{
current_best
=
std
::
move
(
next
);
current_best_name
=
next_name
;
continue
;
}
// The two must agree on which takes precedence.
if
(
next
->
PreferredTo
(
*
current_best
))
{
if
(
current_best
->
PreferredTo
(
*
next
))
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Classes '"
,
current_best_name
,
"' and '"
,
next_name
,
"' both think they should be preferred to each-other. Please "
"add logic to their PreferredTo() methods to avoid this."
);
}
current_best
=
std
::
move
(
next
);
current_best_name
=
next_name
;
}
else
if
(
!
current_best
->
PreferredTo
(
*
next
))
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Classes '"
,
current_best_name
,
"' and '"
,
next_name
,
"' both think they should be dis-preferred to each-other. Please "
"add logic to their PreferredTo() methods to avoid this."
);
}
}
if
(
current_best
==
nullptr
)
{
return
tensorflow
::
errors
::
NotFound
(
"Could not find a best spec for component '"
,
spec
.
name
(),
"' with normalized builder name '"
,
normalized_builder_name
,
"'"
);
}
else
{
*
result
=
std
::
move
(
current_best_name
);
return
tensorflow
::
Status
::
OK
();
}
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component"
,
dragnn
::
runtime
::
Component
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_COMPONENT_H_
#define DRAGNN_RUNTIME_COMPONENT_H_
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Helper method, currently only used by myelination.cc.
string
GetNormalizedComponentBuilderName
(
const
ComponentSpec
&
component_spec
);
// Interface for components.
class
Component
:
public
RegisterableClass
<
Component
>
{
public:
Component
(
const
Component
&
that
)
=
delete
;
Component
&
operator
=
(
const
Component
&
that
)
=
delete
;
virtual
~
Component
()
=
default
;
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
virtual
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
=
0
;
// Evaluates this on the |session_state| and |compute_session|, which must
// both be positioned at the current component. If |component_trace| is
// non-null, overwrites it with extracted traces. On error, returns non-OK.
virtual
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
=
0
;
// Returns the best component for a spec, searching through all registered
// subclasses. This allows specialized implementations to be used.
//
// Sets |result| on success, otherwise returns an error message if a single
// best matching component could not be found. Returned statuses include:
// * OK: If a single best matching component was found.
// * FAILED_PRECONDITION: If an error occurred during the search.
// * NOT_FOUND: If the search was error-free, but no matches were found.
static
tensorflow
::
Status
Select
(
const
ComponentSpec
&
spec
,
string
*
result
);
protected:
Component
()
=
default
;
// Whether this component supports a given spec. |spec| is the full component
// spec and |normalized_builder_name| is the component builder name, with
// Python modules and the suffix "Builder" stripped.
virtual
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
=
0
;
// Whether to prefer this component to another. (Both components must say that
// they support the spec.)
//
// Components must agree on whether they are more or less specialized than
// another component. Feel free to expose methods for subclasses to identify
// themselves; currently, we only have unoptimized implementations (which say
// they are never preferred) and optimized implementations (which say they are
// always preferred).
virtual
bool
PreferredTo
(
const
Component
&
other
)
const
=
0
;
private:
// Helps prevent use of the Create() method; use CreateOrError() instead.
using
RegisterableClass
<
Component
>::
Create
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component"
,
dragnn
::
runtime
::
Component
);
}
// namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_COMPONENT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(::syntaxnet::dragnn::runtime::Component, \
#subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_H_
research/syntaxnet/dragnn/runtime/component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Expects that the two pointers have the same address.
void
ExpectSameAddress
(
const
void
*
pointer1
,
const
void
*
pointer2
)
{
EXPECT_EQ
(
pointer1
,
pointer2
);
}
// A trivial implementation for tests.
class
FooComponent
:
public
Component
{
public:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
==
"FooComponent"
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
FooComponent
);
// Class that always says it's preferred.
class
ImTheBest1
:
public
FooComponent
{
public:
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
==
"ImTheBest"
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
true
;
}
};
class
ImTheBest2
:
public
ImTheBest1
{};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheBest1
);
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheBest2
);
// Class that always says it's dispreferred.
class
ImTheWorst1
:
public
FooComponent
{
public:
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
==
"ImTheWorst"
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
};
class
ImTheWorst2
:
public
ImTheWorst1
{};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheWorst1
);
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
ImTheWorst2
);
// Specialized foo implementation. We use debug-mode down-casting to check that
// the correct sub-class was instantiated.
class
SpecializedFooComponent
:
public
Component
{
public:
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
{
return
tensorflow
::
Status
::
OK
();
}
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
==
"FooComponent"
&&
spec
.
num_actions
()
==
1
;
}
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
true
;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
SpecializedFooComponent
);
TEST
(
ComponentTest
,
NameResolutionError
)
{
ComponentSpec
component_spec
;
EXPECT_DEATH
(
GetNormalizedComponentBuilderName
(
component_spec
),
"No builder name for component spec"
);
}
// Tests that Python-esque module specifiers for builders are normalized
// appropriately.
TEST
(
ComponentTest
,
VariantsOfComponentBuilderNameResolve
)
{
for
(
const
string
&
registered_name
:
{
"FooComponent"
,
"FooComponentBuilder"
,
"module.FooComponent"
,
"module.FooComponentBuilder"
,
"some.long.path.to.module.FooComponent"
,
"some.long.path.to.module.FooComponentBuilder"
})
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
registered_name
);
string
result
;
TF_ASSERT_OK
(
Component
::
Select
(
component_spec
,
&
result
));
EXPECT_EQ
(
result
,
"FooComponent"
);
}
}
TEST
(
ComponentTest
,
ErrorWithBothPreferred
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ImTheBest"
);
string
result
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
result
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
FAILED_PRECONDITION
,
"Classes 'ImTheBest2' and 'ImTheBest1' "
"both think they should be preferred to "
"each-other. Please add logic to their "
"PreferredTo() methods to avoid this."
));
}
TEST
(
ComponentTest
,
ErrorWithNeitherPreferred
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"ImTheWorst"
);
string
result
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
result
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
FAILED_PRECONDITION
,
"Classes 'ImTheWorst2' and 'ImTheWorst1' both think they "
"should be dis-preferred to each-other. Please add logic to "
"their PreferredTo() methods to avoid this."
));
}
TEST
(
ComponentTest
,
DefaultComponent
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"FooComponent"
);
component_spec
.
set_num_actions
(
45
);
string
result
;
TF_EXPECT_OK
(
Component
::
Select
(
component_spec
,
&
result
));
EXPECT_EQ
(
result
,
"FooComponent"
);
}
TEST
(
ComponentTest
,
SpecializedComponent
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"FooComponent"
);
component_spec
.
set_num_actions
(
1
);
string
result
;
TF_EXPECT_OK
(
Component
::
Select
(
component_spec
,
&
result
));
EXPECT_EQ
(
result
,
"SpecializedFooComponent"
);
}
// Tests that Select() returns NOT_FOUND when there is no matching component.
TEST
(
ComponentTest
,
NoMatchingComponentNotFound
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"unknown"
);
string
result
;
EXPECT_THAT
(
Component
::
Select
(
component_spec
,
&
result
),
test
::
IsErrorWithCodeAndSubstr
(
tensorflow
::
error
::
NOT_FOUND
,
"Could not find a best spec for component"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component_transformation.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/component.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
TransformComponents
(
const
string
&
input_master_spec_path
,
const
string
&
output_master_spec_path
)
{
MasterSpec
master_spec
;
TF_RETURN_IF_ERROR
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
input_master_spec_path
,
&
master_spec
));
for
(
ComponentSpec
&
component_spec
:
*
master_spec
.
mutable_component
())
{
TF_RETURN_IF_ERROR
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
}
return
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
output_master_spec_path
,
master_spec
);
}
tensorflow
::
Status
ComponentTransformer
::
ApplyAll
(
ComponentSpec
*
component_spec
)
{
// Limit on the number of iterations, to prevent infinite loops.
static
constexpr
int
kMaxNumIterations
=
1000
;
std
::
set
<
string
>
names
;
// sorted for determinism
for
(
const
Registry
::
Registrar
*
registrar
=
registry
()
->
components
;
registrar
!=
nullptr
;
registrar
=
registrar
->
next
())
{
names
.
insert
(
registrar
->
name
());
}
std
::
vector
<
std
::
unique_ptr
<
ComponentTransformer
>>
transformers
;
transformers
.
reserve
(
names
.
size
());
for
(
const
string
&
name
:
names
)
transformers
.
emplace_back
(
Create
(
name
));
ComponentSpec
local_spec
=
*
component_spec
;
// avoid modification on error
for
(
int
iteration
=
0
;
iteration
<
kMaxNumIterations
;
++
iteration
)
{
const
ComponentSpec
original_spec
=
local_spec
;
for
(
const
auto
&
transformer
:
transformers
)
{
const
string
component_type
=
GetNormalizedComponentBuilderName
(
local_spec
);
TF_RETURN_IF_ERROR
(
transformer
->
Transform
(
component_type
,
&
local_spec
));
}
if
(
tensorflow
::
protobuf
::
util
::
MessageDifferencer
::
Equals
(
local_spec
,
original_spec
))
{
// Converged successfully; make modifications.
*
component_spec
=
local_spec
;
return
tensorflow
::
Status
::
OK
();
}
}
return
tensorflow
::
errors
::
Internal
(
"Failed to converge within "
,
kMaxNumIterations
,
" ComponentTransformer iterations"
);
}
}
// namespace runtime
}
// namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component Transformer"
,
dragnn
::
runtime
::
ComponentTransformer
);
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/component_transformation.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for transforming ComponentSpecs, typically (but not necessarily) in
// ways that are intended to improve speed. For example, a transformer might
// detect a favorable component configuration and replace a generic Component
// implementation with a faster version.
#ifndef DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#define DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Loads a MasterSpec from the |input_master_spec_path|, applies all registered
// ComponentTransformers to it (see ComponentTransformer::ApplyAll() below), and
// writes it to the |output_master_spec_path|. On error, returns non-OK.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow
::
Status
TransformComponents
(
const
string
&
input_master_spec_path
,
const
string
&
output_master_spec_path
);
// Interface for modules that can transform a ComponentSpec, which allows
// transformations to be plugged in on a decentralized basis.
class
ComponentTransformer
:
public
RegisterableClass
<
ComponentTransformer
>
{
public:
ComponentTransformer
(
const
ComponentTransformer
&
that
)
=
delete
;
ComponentTransformer
&
operator
=
(
const
ComponentTransformer
&
that
)
=
delete
;
virtual
~
ComponentTransformer
()
=
default
;
// Repeatedly loops through all registered transformers and applies them to
// the |component_spec| until no more changes occur. For determinism, each
// loop applies the transformers in ascending order of registered name. On
// error, returns non-OK and modifies nothing.
static
tensorflow
::
Status
ApplyAll
(
ComponentSpec
*
component_spec
);
protected:
ComponentTransformer
()
=
default
;
private:
// Helps prevent use of the Create() method.
using
RegisterableClass
<
ComponentTransformer
>::
Create
;
// Modifies the |component_spec|, which is currently configured to use the
// |component_type|, if compatible. On error, returns non-OK and modifies
// nothing. Note that it is not an error if the |component_spec| is simply
// not compatible with the desired transformation.
virtual
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
component_spec
)
=
0
;
};
}
// namespace runtime
}
// namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY
(
"DRAGNN Runtime Component Transformer"
,
dragnn
::
runtime
::
ComponentTransformer
);
}
// namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::ComponentTransformer, #subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
research/syntaxnet/dragnn/runtime/component_transformation_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Transformer that fails if the component type is "fail".
class
MaybeFail
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
component_type
,
ComponentSpec
*
)
override
{
if
(
component_type
==
"fail"
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Boom!"
);
}
return
tensorflow
::
Status
::
OK
();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
MaybeFail
);
// Base class for transformers that change the name of the component, based on
// its current name.
class
ChangeNameBase
:
public
ComponentTransformer
{
public:
// Creates a transformer that changes the component name from |from| to |to|.
explicit
ChangeNameBase
(
const
string
&
from
,
const
string
&
to
)
:
from_
(
from
),
to_
(
to
)
{}
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
,
ComponentSpec
*
component_spec
)
override
{
if
(
component_spec
->
name
()
==
from_
)
component_spec
->
set_name
(
to_
);
return
tensorflow
::
Status
::
OK
();
}
private:
// Component name to look for.
const
string
from_
;
// Component name to change to.
const
string
to_
;
};
// These will convert chain1 => chain2 => chain3.
class
Chain1To2
:
public
ChangeNameBase
{
public:
Chain1To2
()
:
ChangeNameBase
(
"chain1"
,
"chain2"
)
{}
};
class
Chain2To3
:
public
ChangeNameBase
{
public:
Chain2To3
()
:
ChangeNameBase
(
"chain2"
,
"chain3"
)
{}
};
// Adds "." to the name of the component, if it begins with "cycle".
class
Cycle
:
public
ComponentTransformer
{
public:
// Implements ComponentTransformer.
tensorflow
::
Status
Transform
(
const
string
&
,
ComponentSpec
*
component_spec
)
override
{
if
(
component_spec
->
name
().
substr
(
0
,
5
)
==
"cycle"
)
{
component_spec
->
mutable_name
()
->
append
(
"."
);
}
return
tensorflow
::
Status
::
OK
();
}
};
// Intentionally registered out of order to exercise sorting on registered name.
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
Chain1To2
);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
Chain2To3
);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER
(
Cycle
);
// Arbitrary bogus path.
constexpr
char
kInvalidPath
[]
=
"path/to/some/invalid/file"
;
// Returns a unique temporary directory for tests.
string
GetUniqueTemporaryDir
()
{
static
int
counter
=
0
;
const
string
output_dir
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
tensorflow
::
strings
::
StrCat
(
"tmp_"
,
counter
++
));
TF_CHECK_OK
(
tensorflow
::
Env
::
Default
()
->
RecursivelyCreateDir
(
output_dir
));
return
output_dir
;
}
// Returns a MasterSpec parsed from the |text|.
MasterSpec
ParseSpec
(
const
string
&
text
)
{
MasterSpec
master_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
text
,
&
master_spec
));
return
master_spec
;
}
// Tests that TransformComponents() fails if the input master spec path is
// invalid.
TEST
(
TransformComponentsTest
,
InvalidInputMasterSpecPath
)
{
const
string
temp_dir
=
GetUniqueTemporaryDir
();
const
string
output_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"output"
);
EXPECT_FALSE
(
TransformComponents
(
kInvalidPath
,
output_path
).
ok
());
}
// Tests that TransformComponents() fails if the output master spec path is
// invalid.
TEST
(
TransformComponentsTest
,
InvalidOutputMasterSpecPath
)
{
const
string
temp_dir
=
GetUniqueTemporaryDir
();
const
string
input_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"input"
);
const
MasterSpec
empty_spec
;
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
input_path
,
empty_spec
));
EXPECT_FALSE
(
TransformComponents
(
input_path
,
kInvalidPath
).
ok
());
}
// Tests that TransformComponents() fails if one of the ComponentTransformers
// fails.
TEST
(
TransformComponentsTest
,
FailingComponentTransformer
)
{
const
string
temp_dir
=
GetUniqueTemporaryDir
();
const
string
input_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"input"
);
const
string
output_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"output"
);
const
MasterSpec
input_spec
=
ParseSpec
(
R"(
component {
component_builder { registered_name:'foo' }
}
component {
component_builder { registered_name:'fail' }
}
)"
);
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
input_path
,
input_spec
));
EXPECT_THAT
(
TransformComponents
(
input_path
,
output_path
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
}
// Tests that TransformComponents() properly applies all transformations.
TEST
(
TransformComponentsTest
,
Success
)
{
const
string
temp_dir
=
GetUniqueTemporaryDir
();
const
string
input_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"input"
);
const
string
output_path
=
tensorflow
::
io
::
JoinPath
(
temp_dir
,
"output"
);
const
MasterSpec
input_spec
=
ParseSpec
(
R"(
component {
name:'chain1'
component_builder { registered_name:'foo' }
}
component {
name:'irrelevant'
component_builder { registered_name:'bar' }
}
)"
);
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
input_path
,
input_spec
));
TF_ASSERT_OK
(
TransformComponents
(
input_path
,
output_path
));
MasterSpec
actual_spec
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
output_path
,
&
actual_spec
));
const
MasterSpec
expected_spec
=
ParseSpec
(
R"(
component {
name:'chain3'
component_builder { registered_name:'foo' }
}
component {
name:'irrelevant'
component_builder { registered_name:'bar' }
}
)"
);
EXPECT_THAT
(
actual_spec
,
test
::
EqualsProto
(
expected_spec
));
}
// Tests that ComponentTransformer::ApplyAll() makes the expected modifications,
// including chained modifications.
TEST
(
ComponentTransformerTest
,
ApplyAllSuccess
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"foo"
);
component_spec
.
set_name
(
"chain1"
);
ComponentSpec
modified_spec
=
component_spec
;
TF_ASSERT_OK
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
));
modified_spec
.
set_name
(
"chain3"
);
EXPECT_THAT
(
component_spec
,
test
::
EqualsProto
(
modified_spec
));
}
// Tests that ComponentTransformer::ApplyAll() limits the number of iterations.
TEST
(
ComponentTransformerTest
,
ApplyAllLimitIterations
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"foo"
);
component_spec
.
set_name
(
"cycle"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Failed to converge"
));
}
// Tests that ComponentTransformer::ApplyAll() fails if one of the
// ComponentTransformers fails.
TEST
(
ComponentTransformerTest
,
ApplyAllFailure
)
{
ComponentSpec
component_spec
;
component_spec
.
mutable_component_builder
()
->
set_registered_name
(
"fail"
);
EXPECT_THAT
(
ComponentTransformer
::
ApplyAll
(
&
component_spec
),
test
::
IsErrorWithSubstr
(
"Boom!"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/conversion.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <memory>
#include <utility>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/array_variable_store_builder.h"
#include "dragnn/runtime/master.h"
#include "dragnn/runtime/trained_model_variable_store.h"
#include "dragnn/runtime/variable_store.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
ConvertVariables
(
const
string
&
saved_model_dir
,
const
string
&
master_spec_path
,
const
string
&
variables_spec_path
,
const
string
&
variables_data_path
)
{
// Read the trained model.
auto
*
trained_model_store
=
new
TrainedModelVariableStore
();
std
::
unique_ptr
<
VariableStore
>
store
(
trained_model_store
);
TF_RETURN_IF_ERROR
(
trained_model_store
->
Reset
(
saved_model_dir
));
// Wrap the TF store to enable averaging and capturing.
//
// The averaging wrapper currently needs to allow fall-back versions, since
// derived parameters (used for the LSTM network) read averaged versions via
// their TensorFlow-internal logic.
//
// The capturing wrapper must be the outermost, so variable names, formats,
// and content are captured exactly as the components would receive them.
store
.
reset
(
new
TryAveragedVariableStoreWrapper
(
std
::
move
(
store
),
true
));
store
.
reset
(
new
FlexibleMatrixVariableStoreWrapper
(
std
::
move
(
store
)));
auto
*
capturing_store
=
new
CaptureUsedVariableStoreWrapper
(
std
::
move
(
store
));
store
.
reset
(
capturing_store
);
// Initialize a master using the wrapped store. This should populate the
// |capturing_store| with all of the used variables.
MasterSpec
master_spec
;
TF_RETURN_IF_ERROR
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
master_spec_path
,
&
master_spec
));
Master
master
;
TF_RETURN_IF_ERROR
(
master
.
Initialize
(
master_spec
,
std
::
move
(
store
)));
// Convert the used variables into an ArrayVariableStore.
ArrayVariableStoreSpec
variables_spec
;
string
variables_data
;
TF_RETURN_IF_ERROR
(
ArrayVariableStoreBuilder
::
Build
(
capturing_store
->
variables
(),
&
variables_spec
,
&
variables_data
));
// Write the converted variables.
TF_RETURN_IF_ERROR
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
variables_spec_path
,
variables_spec
));
TF_RETURN_IF_ERROR
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
variables_data_path
,
variables_data
));
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/conversion.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for converting pre-trained models into a production-ready format.
#ifndef DRAGNN_RUNTIME_CONVERSION_H_
#define DRAGNN_RUNTIME_CONVERSION_H_
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Converts selected variables from a pre-trained TF model into the format used
// by the ArrayVariableStore. Only converts the variables required to run the
// components in a given MasterSpec.
//
// Inputs:
// saved_model_dir: TF SavedModel directory.
// master_spec_path: Text-format MasterSpec proto.
//
// Outputs:
// variables_spec_path: Text-format ArrayVariableStoreSpec proto.
// variables_data_path: Byte array representing an ArrayVariableStore.
//
// This function will instantiate and initialize a Master using the MasterSpec
// at the |master_path|, so any registered components used by that MasterSpec
// must be linked into the binary.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow
::
Status
ConvertVariables
(
const
string
&
saved_model_dir
,
const
string
&
master_spec_path
,
const
string
&
variables_spec_path
,
const
string
&
variables_data_path
);
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_CONVERSION_H_
research/syntaxnet/dragnn/runtime/conversion_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
class
ConvertVariablesTest
:
public
::
testing
::
Test
{
protected:
// The input files.
const
string
kSavedModelDir
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/rnn_tagger"
);
const
string
kMasterSpecPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec"
);
// Writable output files.
const
string
kVariablesSpecPath
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"variables_spec"
);
const
string
kVariablesDataPath
=
tensorflow
::
io
::
JoinPath
(
tensorflow
::
testing
::
TmpDir
(),
"variables_data"
);
// Bogus file for tests.
const
string
kInvalidPath
=
"path/to/some/invalid/file"
;
// Expected output files.
const
string
kExpectedVariablesSpecPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/conversion_output_variables_spec"
);
const
string
kExpectedVariablesDataPath
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/conversion_output_variables_data"
);
// Local relative paths to the output files.
const
string
kLocalVariablesSpecPath
=
"dragnn/runtime/testdata/"
"conversion_output_variables_spec"
;
const
string
kLocalVariablesDataPath
=
"dragnn/runtime/testdata/"
"conversion_output_variables_data"
;
};
// Tests that the conversion fails if the saved model is invalid.
TEST_F
(
ConvertVariablesTest
,
InvalidSavedModel
)
{
EXPECT_FALSE
(
ConvertVariables
(
kInvalidPath
,
kMasterSpecPath
,
kVariablesSpecPath
,
kVariablesDataPath
)
.
ok
());
}
// Tests that the conversion fails if the master spec is invalid.
TEST_F
(
ConvertVariablesTest
,
InvalidMasterSpec
)
{
EXPECT_FALSE
(
ConvertVariables
(
kSavedModelDir
,
kInvalidPath
,
kVariablesSpecPath
,
kVariablesDataPath
)
.
ok
());
}
// Tests that the conversion fails if the variables spec is invalid.
TEST_F
(
ConvertVariablesTest
,
InvalidVariablesSpec
)
{
EXPECT_FALSE
(
ConvertVariables
(
kSavedModelDir
,
kMasterSpecPath
,
kInvalidPath
,
kVariablesDataPath
)
.
ok
());
}
// Tests that the conversion fails if the variables data is invalid.
TEST_F
(
ConvertVariablesTest
,
InvalidVariablesData
)
{
EXPECT_FALSE
(
ConvertVariables
(
kSavedModelDir
,
kMasterSpecPath
,
kVariablesSpecPath
,
kInvalidPath
)
.
ok
());
}
// Tests that the conversion succeeds on the pre-trained inputs and reproduces
// expected outputs.
TEST_F
(
ConvertVariablesTest
,
RegressionTest
)
{
TF_EXPECT_OK
(
ConvertVariables
(
kSavedModelDir
,
kMasterSpecPath
,
kVariablesSpecPath
,
kVariablesDataPath
));
ArrayVariableStoreSpec
actual_variables_spec
;
string
actual_variables_data
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
kVariablesSpecPath
,
&
actual_variables_spec
));
TF_ASSERT_OK
(
tensorflow
::
ReadFileToString
(
tensorflow
::
Env
::
Default
(),
kVariablesDataPath
,
&
actual_variables_data
));
if
(
false
)
{
TF_ASSERT_OK
(
tensorflow
::
WriteTextProto
(
tensorflow
::
Env
::
Default
(),
kLocalVariablesSpecPath
,
actual_variables_spec
));
TF_ASSERT_OK
(
tensorflow
::
WriteStringToFile
(
tensorflow
::
Env
::
Default
(),
kLocalVariablesDataPath
,
actual_variables_data
));
}
else
{
ArrayVariableStoreSpec
expected_variables_spec
;
string
expected_variables_data
;
TF_ASSERT_OK
(
tensorflow
::
ReadTextProto
(
tensorflow
::
Env
::
Default
(),
kExpectedVariablesSpecPath
,
&
expected_variables_spec
));
TF_ASSERT_OK
(
tensorflow
::
ReadFileToString
(
tensorflow
::
Env
::
Default
(),
kExpectedVariablesDataPath
,
&
expected_variables_data
));
EXPECT_THAT
(
actual_variables_spec
,
test
::
EqualsProto
(
expected_variables_spec
));
EXPECT_EQ
(
actual_variables_data
,
expected_variables_data
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/converter.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Tool for converting trained models for use in the runtime.
#include <set>
#include <string>
#include <vector>
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/conversion.h"
#include "dragnn/runtime/myelin/myelination.h"
#include "dragnn/runtime/xla/xla_compilation.h"
#include "syntaxnet/base.h"
#include "sling/base/flags.h" // TF does not support flags, but SLING does
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
DEFINE_string
(
saved_model_dir
,
""
,
"Path to TF SavedModel directory."
);
DEFINE_string
(
master_spec_file
,
""
,
"Path to text-format MasterSpec proto."
);
DEFINE_string
(
myelin_components
,
""
,
"Comma-delimited list of components to compile using Myelin, if any"
);
DEFINE_string
(
xla_components
,
""
,
"Comma-delimited list of components to compile using XLA, if any."
);
DEFINE_string
(
xla_model_name
,
""
,
"Name to apply to XLA-based components."
);
DEFINE_string
(
output_dir
,
""
,
"Path to an output directory. This will be filled with the following "
"files and subdirectories. MasterSpec: Converted text-format MasterSpec "
"proto. ArrayVariableStoreSpec: Converted text-format variable spec. "
"ArrayVariableStoreData: Converted variable data. myelin/*: Compiled "
"Myelin components, if any. xla/*: Compiled XLA components, if any."
);
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Splits the |list| on commas and returns the set of elements.
std
::
set
<
string
>
Split
(
const
string
&
list
)
{
const
std
::
vector
<
string
>
elements
=
tensorflow
::
str_util
::
Split
(
list
,
","
);
return
std
::
set
<
string
>
(
elements
.
begin
(),
elements
.
end
());
}
// Creates an empty directory at the |path|. If the directory exists, it is
// recursively deleted first.
void
CreateEmptyDir
(
const
string
&
path
)
{
// Ensure that the directory exists; otherwise DeleteRecursively() may fail.
TF_QCHECK_OK
(
tensorflow
::
Env
::
Default
()
->
RecursivelyCreateDir
(
path
));
int64
unused_undeleted_files
,
unused_undeleted_dirs
;
TF_QCHECK_OK
(
tensorflow
::
Env
::
Default
()
->
DeleteRecursively
(
path
,
&
unused_undeleted_files
,
&
unused_undeleted_dirs
));
TF_QCHECK_OK
(
tensorflow
::
Env
::
Default
()
->
RecursivelyCreateDir
(
path
));
}
// Performs Myelin compilation on the MasterSpec at |master_spec_path|, if
// requested. Returns the path to the converted or original MasterSpec.
string
CompileMyelin
(
const
string
&
master_spec_path
)
{
const
std
::
set
<
string
>
components
=
Split
(
FLAGS_myelin_components
);
if
(
components
.
empty
())
return
master_spec_path
;
LOG
(
INFO
)
<<
"Compiling Myelin in MasterSpec "
<<
master_spec_path
;
const
string
dir
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"myelin"
);
CreateEmptyDir
(
dir
);
TF_QCHECK_OK
(
MyelinateCells
(
FLAGS_saved_model_dir
,
master_spec_path
,
components
,
dir
));
return
tensorflow
::
io
::
JoinPath
(
dir
,
"master-spec"
);
}
// Performs XLA compilation on the MasterSpec at |master_spec_path|, if
// requested. Returns the path to the converted or original MasterSpec.
string
CompileXla
(
const
string
&
master_spec_path
)
{
const
std
::
set
<
string
>
components
=
Split
(
FLAGS_xla_components
);
if
(
components
.
empty
())
return
master_spec_path
;
LOG
(
INFO
)
<<
"Compiling XLA in MasterSpec "
<<
master_spec_path
;
const
string
dir
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"xla"
);
CreateEmptyDir
(
dir
);
TF_QCHECK_OK
(
XlaCompileCells
(
FLAGS_saved_model_dir
,
master_spec_path
,
components
,
FLAGS_xla_model_name
,
dir
));
return
tensorflow
::
io
::
JoinPath
(
dir
,
"master-spec"
);
}
// Transforms the MasterSpec at |master_spec_path|, and returns the path to the
// transformed MasterSpec.
string
Transform
(
const
string
&
master_spec_path
)
{
LOG
(
INFO
)
<<
"Transforming MasterSpec "
<<
master_spec_path
;
const
string
output_master_spec_path
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"MasterSpec"
);
TF_QCHECK_OK
(
TransformComponents
(
master_spec_path
,
output_master_spec_path
));
return
output_master_spec_path
;
}
// Performs final variable conversion on the MasterSpec at |master_spec_path|.
void
Convert
(
const
string
&
master_spec_path
)
{
LOG
(
INFO
)
<<
"Converting MasterSpec "
<<
master_spec_path
;
const
string
variables_data_path
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"ArrayVariableStoreData"
);
const
string
variables_spec_path
=
tensorflow
::
io
::
JoinPath
(
FLAGS_output_dir
,
"ArrayVariableStoreSpec"
);
TF_QCHECK_OK
(
ConvertVariables
(
FLAGS_saved_model_dir
,
master_spec_path
,
variables_spec_path
,
variables_data_path
));
}
// Implements main().
void
Main
()
{
CreateEmptyDir
(
FLAGS_output_dir
);
string
master_spec_path
=
FLAGS_master_spec_file
;
master_spec_path
=
CompileMyelin
(
master_spec_path
);
master_spec_path
=
CompileXla
(
master_spec_path
);
master_spec_path
=
Transform
(
master_spec_path
);
Convert
(
master_spec_path
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
int
main
(
int
argc
,
char
**
argv
)
{
sling
::
Flag
::
ParseCommandLineFlags
(
&
argc
,
argv
,
true
);
syntaxnet
::
dragnn
::
runtime
::
Main
();
return
0
;
}
research/syntaxnet/dragnn/runtime/converter_test.sh
deleted
100755 → 0
View file @
a4bb31d0
#!/bin/bash
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Test for converter tool. To update the testdata, run the test with a single
# command-line argument specifying the path to the testdata directory.
set
-e
set
-u
# Infer the location of the data dependencies.
if
[[
-d
"
${
BASH_SOURCE
[0]
}
.runfiles"
]]
;
then
# Use the ".runfiles" directory if available (this typically happens when
# running manually). SyntaxNet does not specify a workspace name, so the
# runfiles are placed in ".runfiles/__main__". If SyntaxNet is configured
# with a workspace name, then change "__main__" to that name. See
# https://github.com/bazelbuild/bazel/wiki/Updating-the-runfiles-tree-structure
RUNFILES
=
"
${
BASH_SOURCE
[0]
}
.runfiles/__main__"
else
# Otherwise, use this recipe borrowed from
# https://github.com/bazelbuild/bazel/blob/7d265e07e7a1e37f04d53342710e4f21d9ee8083/examples/shell/test.sh#L21
# shellcheck disable=SC2091
RUNFILES
=
"
${
RUNFILES
:-
"
$(
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
)
"
;
pwd
)
"
}
"
fi
readonly
RUNFILES
readonly
RUNTIME
=
"
${
RUNFILES
}
/dragnn/runtime"
readonly
CONVERTER
=
"
${
RUNTIME
}
/converter"
readonly
SAVED_MODEL
=
"
${
RUNTIME
}
/testdata/rnn_tagger"
readonly
MASTER_SPEC
=
"
${
SAVED_MODEL
}
/assets.extra/master_spec"
readonly
EXPECTED
=
"
${
RUNTIME
}
/testdata/converter_output"
readonly
OUTPUT
=
"
${
TEST_TMPDIR
:-
/tmp/
$$
}
/converted"
# Fails the test with a message.
function
fail
()
{
echo
"
$@
"
1>&2
# print to stderr
exit
1
}
# Asserts that a file exists.
function
assert_file_exists
()
{
if
[[
!
-f
"
$1
"
]]
;
then
fail
"missing file:
$1
"
fi
}
# Asserts that two files have the same content.
function
assert_file_content_eq
()
{
assert_file_exists
"
$1
"
assert_file_exists
"
$2
"
if
!
diff
-u
"
$1
"
"
$2
"
;
then
fail
"files differ:
$1
$2
"
fi
}
rm
-rf
"
${
OUTPUT
}
"
"
${
CONVERTER
}
"
\
--saved_model_dir
=
"
${
SAVED_MODEL
}
"
\
--master_spec_file
=
"
${
MASTER_SPEC
}
"
\
--output_dir
=
"
${
OUTPUT
}
"
\
--logtostderr
for
file
in
\
'MasterSpec'
\
'ArrayVariableStoreData'
\
'ArrayVariableStoreSpec'
;
do
if
[[
$#
-gt
0
]]
;
then
# Update expected output.
rm
-f
"
$1
/
${
file
}
"
cp
-f
"
${
OUTPUT
}
/
${
file
}
"
"
$1
/
${
file
}
"
else
# Compare to expected output.
assert_file_content_eq
"
${
OUTPUT
}
/
${
file
}
"
"
${
EXPECTED
}
/
${
file
}
"
fi
done
rm
-rf
"
${
OUTPUT
}
"
research/syntaxnet/dragnn/runtime/dynamic_component.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// The DynamicComponent is the runtime analogue of the DynamicComponentBuilder
// in the Python codebase. The role of the DynamicComponent is to manage the
// loop over transition steps, including:
// * Allocating stepwise memory for network states and operands.
// * Performing some computation at each step.
// * Advancing the transition state until terminal.
//
// Note that the number of transition taken on any given evaluation of the
// DynamicComponent cannot be determined in advance.
//
// The core computational work is delegated to a NetworkUnit, which is evaluated
// at each transition step. This makes the DynamicComponent flexible, since it
// can be applied to any NetworkUnit implementation, but it can be significantly
// more efficient to use a task-specific component implementation. For example,
// the "shift-only" transition system merely scans the input tokens, and in that
// case we could replace the incremental loop with a "bulk" computation.
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Performs an incremental computation, one transition at a time.
class
DynamicComponent
:
public
Component
{
protected:
// Implements Component.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
;
tensorflow
::
Status
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
override
;
// This class is intended to support all DynamicComponent layers. We currently
// prefer to return `true` here and throw errors in Initialize() if a
// particular feature is not supported.
bool
Supports
(
const
ComponentSpec
&
spec
,
const
string
&
normalized_builder_name
)
const
override
{
return
normalized_builder_name
==
"DynamicComponent"
;
}
// This class is not optimized, so any other supported subclasses of Component
// should be preferred.
bool
PreferredTo
(
const
Component
&
other
)
const
override
{
return
false
;
}
private:
// Name of this component.
string
name_
;
// Network unit that produces logits.
std
::
unique_ptr
<
NetworkUnit
>
network_unit_
;
// Whether the transition system is deterministic.
bool
deterministic_
=
false
;
// Handle to the network unit logits. Valid iff |deterministic_| is false.
LayerHandle
<
float
>
logits_handle_
;
};
tensorflow
::
Status
DynamicComponent
::
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
{
name_
=
component_spec
.
name
();
if
(
!
component_spec
.
attention_component
().
empty
())
{
return
tensorflow
::
errors
::
Unimplemented
(
"Attention is not supported"
);
}
TF_RETURN_IF_ERROR
(
NetworkUnit
::
CreateOrError
(
NetworkUnit
::
GetClassName
(
component_spec
),
&
network_unit_
));
TF_RETURN_IF_ERROR
(
network_unit_
->
Initialize
(
component_spec
,
variable_store
,
network_state_manager
,
extension_manager
));
// Logits are unnecesssary when the component is deterministic.
deterministic_
=
TransitionSystemTraits
(
component_spec
).
is_deterministic
;
if
(
!
deterministic_
)
{
const
string
logits_name
=
network_unit_
->
GetLogitsName
();
if
(
logits_name
.
empty
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Network unit does not produce logits: "
,
component_spec
.
network_unit
().
ShortDebugString
());
}
size_t
dimension
=
0
;
TF_RETURN_IF_ERROR
(
network_state_manager
->
LookupLayer
(
name_
,
logits_name
,
&
dimension
,
&
logits_handle_
));
if
(
dimension
!=
component_spec
.
num_actions
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Dimension mismatch between network unit logits ("
,
dimension
,
") and ComponentSpec.num_actions ("
,
component_spec
.
num_actions
(),
") in component '"
,
name_
,
"'"
);
}
}
return
tensorflow
::
Status
::
OK
();
}
// No batches or beams.
constexpr
int
kNumItems
=
1
;
tensorflow
::
Status
DynamicComponent
::
Evaluate
(
SessionState
*
session_state
,
ComputeSession
*
compute_session
,
ComponentTrace
*
component_trace
)
const
{
NetworkStates
&
network_states
=
session_state
->
network_states
;
for
(
size_t
step_index
=
0
;
!
compute_session
->
IsTerminal
(
name_
);
++
step_index
)
{
network_states
.
AddStep
();
TF_RETURN_IF_ERROR
(
network_unit_
->
Evaluate
(
step_index
,
session_state
,
compute_session
));
// If the component is deterministic, take the oracle transition instead of
// predicting the next transition using the logits.
if
(
deterministic_
)
{
compute_session
->
AdvanceFromOracle
(
name_
);
}
else
{
// AddStep() may invalidate the logits (due to reallocation), so the layer
// lookup cannot be hoisted out of this loop.
const
Vector
<
float
>
logits
(
network_states
.
GetLayer
(
logits_handle_
).
row
(
step_index
));
if
(
!
compute_session
->
AdvanceFromPrediction
(
name_
,
logits
.
data
(),
kNumItems
,
logits
.
size
()))
{
return
tensorflow
::
errors
::
Internal
(
"Error in ComputeSession::AdvanceFromPrediction()"
);
}
}
}
return
tensorflow
::
Status
::
OK
();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT
(
DynamicComponent
);
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/dynamic_component_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
using
::
testing
::
_
;
using
::
testing
::
Return
;
constexpr
size_t
kStepsDim
=
41
;
constexpr
size_t
kNumSteps
=
23
;
// Fills each row of its logits with the step index.
class
StepsNetwork
:
public
NetworkUnit
{
public:
// Implements NetworkUnit.
tensorflow
::
Status
Initialize
(
const
ComponentSpec
&
component_spec
,
VariableStore
*
variable_store
,
NetworkStateManager
*
network_state_manager
,
ExtensionManager
*
extension_manager
)
override
{
return
network_state_manager
->
AddLayer
(
"steps"
,
kStepsDim
,
&
handle_
);
}
string
GetLogitsName
()
const
override
{
return
"steps"
;
}
tensorflow
::
Status
Evaluate
(
size_t
step_index
,
SessionState
*
session_state
,
ComputeSession
*
compute_session
)
const
override
{
const
MutableVector
<
float
>
logits
=
session_state
->
network_states
.
GetLayer
(
handle_
).
row
(
step_index
);
for
(
float
&
logit
:
logits
)
logit
=
step_index
;
return
tensorflow
::
Status
::
OK
();
}
private:
// Handle to the logits layer.
LayerHandle
<
float
>
handle_
;
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT
(
StepsNetwork
);
// As above, but does not report a logits layer.
class
NoLogitsNetwork
:
public
StepsNetwork
{
public:
// Implements NetworkUnit.
string
GetLogitsName
()
const
override
{
return
""
;
}
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT
(
NoLogitsNetwork
);
class
DynamicComponentTest
:
public
NetworkTestBase
{
protected:
// Creates a component, initializes it based on the |component_spec_text| and
// |network_unit_name|, and evaluates it. On error, returns non-OK.
tensorflow
::
Status
Run
(
const
string
&
component_spec_text
,
const
string
&
network_unit_name
)
{
ComponentSpec
component_spec
;
CHECK
(
TextFormat
::
ParseFromString
(
component_spec_text
,
&
component_spec
));
component_spec
.
set_name
(
kTestComponentName
);
component_spec
.
mutable_network_unit
()
->
set_registered_name
(
network_unit_name
);
// Neither DynamicComponent nor the test networks use linked embeddings, so
// a trivial network suffices.
AddComponent
(
kTestComponentName
);
TF_RETURN_IF_ERROR
(
Component
::
CreateOrError
(
"DynamicComponent"
,
&
component_
));
TF_RETURN_IF_ERROR
(
component_
->
Initialize
(
component_spec
,
&
variable_store_
,
&
network_state_manager_
,
&
extension_manager_
));
network_states_
.
Reset
(
&
network_state_manager_
);
StartComponent
(
0
);
// DynamicComponent will add steps
session_state_
.
extensions
.
Reset
(
&
extension_manager_
);
TF_RETURN_IF_ERROR
(
component_
->
Evaluate
(
&
session_state_
,
&
compute_session_
,
nullptr
));
steps_
=
GetLayer
(
kTestComponentName
,
"steps"
);
return
tensorflow
::
Status
::
OK
();
}
std
::
unique_ptr
<
Component
>
component_
;
Matrix
<
float
>
steps_
;
};
// Tests that DynamicComponent fails if the spec uses attention.
TEST_F
(
DynamicComponentTest
,
UnsupportedAttention
)
{
EXPECT_THAT
(
Run
(
"attention_component: 'foo'"
,
"NoLogitsNetwork"
),
test
::
IsErrorWithSubstr
(
"Attention is not supported"
));
}
// Tests that DynamicComponent fails if the network does not produce logits.
TEST_F
(
DynamicComponentTest
,
NoLogits
)
{
EXPECT_THAT
(
Run
(
""
,
"NoLogitsNetwork"
),
test
::
IsErrorWithSubstr
(
"Network unit does not produce logits"
));
}
// Tests that DynamicComponent fails if the logits do not have the required
// dimension.
TEST_F
(
DynamicComponentTest
,
MismatchedLogitsDimension
)
{
EXPECT_THAT
(
Run
(
"num_actions: 42"
,
"StepsNetwork"
),
test
::
IsErrorWithSubstr
(
"Dimension mismatch between network unit logits "
"(41) and ComponentSpec.num_actions (42)"
));
}
// Tests that DynamicComponent fails if ComputeSession::AdvanceFromPrediction()
// returns false.
TEST_F
(
DynamicComponentTest
,
FailToAdvanceFromPrediction
)
{
EXPECT_CALL
(
compute_session_
,
IsTerminal
(
_
)).
WillRepeatedly
(
Return
(
false
));
EXPECT_CALL
(
compute_session_
,
AdvanceFromPrediction
(
_
,
_
,
_
,
_
))
.
WillOnce
(
Return
(
false
));
EXPECT_THAT
(
Run
(
"num_actions: 41"
,
"StepsNetwork"
),
test
::
IsErrorWithSubstr
(
"Error in ComputeSession::AdvanceFromPrediction()"
));
}
// Tests that DynamicComponent evaluates its network unit once per transition,
// each time passing the proper step index.
TEST_F
(
DynamicComponentTest
,
Steps
)
{
SetupTransitionLoop
(
kNumSteps
);
// Accept |num_steps| transition steps.
EXPECT_CALL
(
compute_session_
,
AdvanceFromPrediction
(
_
,
_
,
_
,
_
))
.
Times
(
kNumSteps
)
.
WillRepeatedly
(
Return
(
true
));
TF_ASSERT_OK
(
Run
(
"num_actions: 41"
,
"StepsNetwork"
));
ASSERT_EQ
(
steps_
.
num_rows
(),
kNumSteps
);
for
(
size_t
step_index
=
0
;
step_index
<
kNumSteps
;
++
step_index
)
{
ExpectVector
(
steps_
.
row
(
step_index
),
kStepsDim
,
step_index
);
}
}
// Tests that DynamicComponent calls ComputeSession::AdvanceFromOracle() and
// does not use logits when the component is deterministic.
TEST_F
(
DynamicComponentTest
,
Determinstic
)
{
SetupTransitionLoop
(
kNumSteps
);
// Take the oracle transition instead of predicting from logits.
EXPECT_CALL
(
compute_session_
,
AdvanceFromOracle
(
_
)).
Times
(
kNumSteps
);
TF_EXPECT_OK
(
Run
(
"num_actions: 1"
,
"NoLogitsNetwork"
));
// The NoLogitsNetwork still produces the "steps" layer, even if it does not
// mark them as its logits.
ASSERT_EQ
(
steps_
.
num_rows
(),
kNumSteps
);
for
(
size_t
step_index
=
0
;
step_index
<
kNumSteps
;
++
step_index
)
{
ExpectVector
(
steps_
.
row
(
step_index
),
kStepsDim
,
step_index
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment