Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
edea2b67
Commit
edea2b67
authored
May 11, 2018
by
Terry Koo
Browse files
Remove runtime because reasons.
parent
a4bb31d0
Changes
291
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
2239 deletions
+0
-2239
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/assets/master_spec
...net/dragnn/runtime/testdata/rnn_tagger/assets/master_spec
+0
-154
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/saved_model.pb
...ntaxnet/dragnn/runtime/testdata/rnn_tagger/saved_model.pb
+0
-0
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/variables/variables.data-00000-of-00001
...stdata/rnn_tagger/variables/variables.data-00000-of-00001
+0
-0
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/variables/variables.index
...gnn/runtime/testdata/rnn_tagger/variables/variables.index
+0
-0
research/syntaxnet/dragnn/runtime/testdata/ten_bytes
research/syntaxnet/dragnn/runtime/testdata/ten_bytes
+0
-1
research/syntaxnet/dragnn/runtime/trained_model.cc
research/syntaxnet/dragnn/runtime/trained_model.cc
+0
-119
research/syntaxnet/dragnn/runtime/trained_model.h
research/syntaxnet/dragnn/runtime/trained_model.h
+0
-75
research/syntaxnet/dragnn/runtime/trained_model_test.cc
research/syntaxnet/dragnn/runtime/trained_model_test.cc
+0
-132
research/syntaxnet/dragnn/runtime/trained_model_variable_store.cc
.../syntaxnet/dragnn/runtime/trained_model_variable_store.cc
+0
-192
research/syntaxnet/dragnn/runtime/trained_model_variable_store.h
...h/syntaxnet/dragnn/runtime/trained_model_variable_store.h
+0
-82
research/syntaxnet/dragnn/runtime/trained_model_variable_store_test.cc
...axnet/dragnn/runtime/trained_model_variable_store_test.cc
+0
-384
research/syntaxnet/dragnn/runtime/transition_system_traits.cc
...arch/syntaxnet/dragnn/runtime/transition_system_traits.cc
+0
-87
research/syntaxnet/dragnn/runtime/transition_system_traits.h
research/syntaxnet/dragnn/runtime/transition_system_traits.h
+0
-55
research/syntaxnet/dragnn/runtime/transition_system_traits_test.cc
...syntaxnet/dragnn/runtime/transition_system_traits_test.cc
+0
-156
research/syntaxnet/dragnn/runtime/type_keyed_set.h
research/syntaxnet/dragnn/runtime/type_keyed_set.h
+0
-106
research/syntaxnet/dragnn/runtime/type_keyed_set_test.cc
research/syntaxnet/dragnn/runtime/type_keyed_set_test.cc
+0
-122
research/syntaxnet/dragnn/runtime/unicode_dictionary.cc
research/syntaxnet/dragnn/runtime/unicode_dictionary.cc
+0
-93
research/syntaxnet/dragnn/runtime/unicode_dictionary.h
research/syntaxnet/dragnn/runtime/unicode_dictionary.h
+0
-122
research/syntaxnet/dragnn/runtime/unicode_dictionary_test.cc
research/syntaxnet/dragnn/runtime/unicode_dictionary_test.cc
+0
-161
research/syntaxnet/dragnn/runtime/variable_store.h
research/syntaxnet/dragnn/runtime/variable_store.h
+0
-198
No files found.
Too many changes to show.
To preserve performance only
291 of 291+
files are displayed.
Plain diff
Email patch
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/assets/master_spec
deleted
100644 → 0
View file @
a4bb31d0
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_pattern: "resources/component_0_rnn/resource_0_words-embedding-input/part_0"
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_pattern: "resources/component_0_rnn/resource_1_words-vocab-input/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_pattern: "resources/component_0_rnn/resource_2_char-ngram-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_pattern: "resources/component_0_rnn/resource_3_word-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_pattern: "resources/component_0_rnn/resource_4_label-map/part_0"
file_format: "text"
record_format: ""
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: 32
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: 64
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "DynamicComponentBuilder"
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_pattern: "resources/component_1_tagger/resource_0_tag-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_pattern: "resources/component_1_tagger/resource_1_tag-to-category/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_pattern: "resources/component_0_rnn/resource_4_label-map/part_0"
file_format: "text"
record_format: ""
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: 32
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "DynamicComponentBuilder"
}
}
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/saved_model.pb
deleted
100644 → 0
View file @
a4bb31d0
File deleted
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/variables/variables.data-00000-of-00001
deleted
100644 → 0
View file @
a4bb31d0
File deleted
research/syntaxnet/dragnn/runtime/testdata/rnn_tagger/variables/variables.index
deleted
100644 → 0
View file @
a4bb31d0
File deleted
research/syntaxnet/dragnn/runtime/testdata/ten_bytes
deleted
100644 → 0
View file @
a4bb31d0
0123456789
\ No newline at end of file
research/syntaxnet/dragnn/runtime/trained_model.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model.h"
#include <unordered_set>
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
TrainedModel
::
Reset
(
const
string
&
saved_model_dir
)
{
const
std
::
unordered_set
<
string
>
tags
=
{
tensorflow
::
kSavedModelTagServe
};
tensorflow
::
SavedModelBundle
saved_model
;
TF_RETURN_IF_ERROR
(
tensorflow
::
LoadSavedModel
({},
{},
saved_model_dir
,
tags
,
&
saved_model
));
// Success; make modifications.
saved_model_
.
session
=
std
::
move
(
saved_model
.
session
);
saved_model_
.
meta_graph_def
=
std
::
move
(
saved_model
.
meta_graph_def
);
nodes_
.
clear
();
const
tensorflow
::
GraphDef
&
graph
=
saved_model_
.
meta_graph_def
.
graph_def
();
for
(
const
tensorflow
::
NodeDef
&
node
:
graph
.
node
())
{
nodes_
[
node
.
name
()]
=
&
node
;
}
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
TrainedModel
::
EvaluateTensor
(
const
string
&
name
,
tensorflow
::
Tensor
*
tensor
)
const
{
if
(
saved_model_
.
session
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"TF Session is not active"
);
}
// For some reason, runtime hook nodes cannot be evaluated without feeding an
// input batch. An empty batch currently works, but if DRAGNN starts failing
// on empty batches, a reasonable alternative is a batch of empty strings.
const
string
input_name
=
"annotation/ComputeSession/InputBatch"
;
const
tensorflow
::
Tensor
empty_batch
(
tensorflow
::
DT_STRING
,
tensorflow
::
TensorShape
({
0
}));
// Evaluate the variable in the session.
std
::
vector
<
tensorflow
::
Tensor
>
outputs
;
tensorflow
::
Status
status
=
saved_model_
.
session
->
Run
(
{{
input_name
,
empty_batch
}},
{
name
},
{},
&
outputs
);
if
(
!
status
.
ok
())
{
// Attach some extra information to the session error.
return
tensorflow
::
Status
(
status
.
code
(),
tensorflow
::
strings
::
StrCat
(
"Failed to evaluate tensor '"
,
name
,
"': "
,
status
.
error_message
()));
}
if
(
outputs
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
Unknown
(
"Expected exactly one output, but got "
,
outputs
.
size
(),
" outputs"
);
}
*
tensor
=
outputs
[
0
];
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
TrainedModel
::
LookupNode
(
const
string
&
name
,
const
tensorflow
::
NodeDef
**
node
)
const
{
if
(
saved_model_
.
session
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"TF Session is not active"
);
}
const
auto
it
=
nodes_
.
find
(
name
);
if
(
it
==
nodes_
.
end
())
{
return
tensorflow
::
errors
::
NotFound
(
"Unknown node: '"
,
name
,
"'"
);
}
*
node
=
it
->
second
;
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
TrainedModel
::
GraphDef
(
const
tensorflow
::
GraphDef
**
graph
)
const
{
if
(
saved_model_
.
session
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"TF Session is not active"
);
}
*
graph
=
&
saved_model_
.
meta_graph_def
.
graph_def
();
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
TrainedModel
::
Close
()
{
if
(
saved_model_
.
session
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"TF Session is not active"
);
}
tensorflow
::
Status
status
=
saved_model_
.
session
->
Close
();
saved_model_
.
session
.
reset
();
saved_model_
.
meta_graph_def
.
Clear
();
nodes_
.
clear
();
return
status
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/trained_model.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRAINED_MODEL_H_
#define DRAGNN_RUNTIME_TRAINED_MODEL_H_
#include <map>
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A trained DRAGNN model, which can be queried for nodes and tensors.
class
TrainedModel
{
public:
// Creates an uninitialized model; call Reset() before use.
TrainedModel
()
=
default
;
// Loads the TF SavedModel at the |saved_model_dir|, replacing the current
// model, if any. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
Reset
(
const
string
&
saved_model_dir
);
// Evaluates the tensor with the |name| in the |session_| and sets |tensor| to
// the result. On error, returns non-OK and modifies nothing.
//
// NB: Tensors that are embedded inside a tf.while_loop() cannot be evaluated.
// Such evaluations fail with errors like "Retval[0] does not have value".
tensorflow
::
Status
EvaluateTensor
(
const
string
&
name
,
tensorflow
::
Tensor
*
tensor
)
const
;
// Finds the node with the |name| in the |graph_| and points the |node| at it.
// On error, returns non-OK and modifies nothing.
tensorflow
::
Status
LookupNode
(
const
string
&
name
,
const
tensorflow
::
NodeDef
**
node
)
const
;
// Points |graph| at the GraphDef for the current model. It is an error if
// there is no current model.
tensorflow
::
Status
GraphDef
(
const
tensorflow
::
GraphDef
**
graph
)
const
;
// Discards the current model. It is an error if there is no current model.
// On error, returns non-OK but still discards the model.
tensorflow
::
Status
Close
();
private:
// TF SavedModel that contains the trained DRAGNN model.
tensorflow
::
SavedModelBundle
saved_model_
;
// Nodes in the TF graph, indexed by name.
std
::
map
<
string
,
const
tensorflow
::
NodeDef
*>
nodes_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRAINED_MODEL_H_
research/syntaxnet/dragnn/runtime/trained_model_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model.h"
#include <stddef.h>
#include <string>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Relative path to a saved model.
constexpr
char
kSavedModelDir
[]
=
"dragnn/runtime/testdata/rnn_tagger"
;
// A valid tensor name in the test model and its dimensions.
constexpr
char
kTensorName
[]
=
"tagger/weights_0/ExponentialMovingAverage"
;
constexpr
size_t
kTensorRows
=
160
;
constexpr
size_t
kTensorColumns
=
64
;
// Returns a valid saved model directory.
string
GetSavedModelDir
()
{
return
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
kSavedModelDir
);
}
// Tests that TrainedModel can initialize itself from a valid saved model,
// retrieve tensors and nodes, and close itself. This is done in one test to
// avoid multiple (expensive) saved model loads.
TEST
(
TrainedModelTest
,
ResetQueryAndClose
)
{
TrainedModel
trained_model
;
TF_ASSERT_OK
(
trained_model
.
Reset
(
GetSavedModelDir
()));
// Look up a valid tensor.
tensorflow
::
Tensor
tensor
;
TF_ASSERT_OK
(
trained_model
.
EvaluateTensor
(
kTensorName
,
&
tensor
));
ASSERT_EQ
(
tensor
.
dims
(),
2
);
EXPECT_EQ
(
tensor
.
dim_size
(
0
),
kTensorRows
);
EXPECT_EQ
(
tensor
.
dim_size
(
1
),
kTensorColumns
);
// Look up an invalid tensor.
EXPECT_FALSE
(
trained_model
.
EvaluateTensor
(
"invalid"
,
&
tensor
).
ok
());
// Still have the old tensor contents.
ASSERT_EQ
(
tensor
.
dims
(),
2
);
EXPECT_EQ
(
tensor
.
dim_size
(
0
),
kTensorRows
);
EXPECT_EQ
(
tensor
.
dim_size
(
1
),
kTensorColumns
);
// Look up a valid node. Note that the tensor name doubles as a node name.
const
tensorflow
::
NodeDef
*
node
=
nullptr
;
TF_ASSERT_OK
(
trained_model
.
LookupNode
(
kTensorName
,
&
node
));
ASSERT_NE
(
node
,
nullptr
);
EXPECT_EQ
(
node
->
name
(),
kTensorName
);
// Look up an invalid node.
ASSERT_THAT
(
trained_model
.
LookupNode
(
"invalid"
,
&
node
),
test
::
IsErrorWithSubstr
(
"Unknown node"
));
// Still have the old node.
ASSERT_NE
(
node
,
nullptr
);
EXPECT_EQ
(
node
->
name
(),
kTensorName
);
// Get the current Graph.
const
tensorflow
::
GraphDef
*
graph_def
=
nullptr
;
TF_ASSERT_OK
(
trained_model
.
GraphDef
(
&
graph_def
));
EXPECT_GT
(
graph_def
->
node_size
(),
0
);
// First Close() is OK, second fails because already closed.
TF_EXPECT_OK
(
trained_model
.
Close
());
EXPECT_THAT
(
trained_model
.
Close
(),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
// Tests that TrainedModel::Reset() fails on an invalid path.
TEST
(
TrainedModelTest
,
InvalidPath
)
{
TrainedModel
trained_model
;
EXPECT_FALSE
(
trained_model
.
Reset
(
"invalid/path"
).
ok
());
}
// Tests that TrainedModel::Close() fails if there is no model.
TEST
(
TrainedModelTest
,
CloseFailsBeforeReset
)
{
TrainedModel
trained_model
;
EXPECT_THAT
(
trained_model
.
Close
(),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
// Tests that TrainedModel::GraphDef() fails if there is no active session.
TEST
(
TrainedModelTest
,
GraphDefFailsBeforeReset
)
{
const
tensorflow
::
GraphDef
*
graph_def
=
nullptr
;
TrainedModel
trained_model
;
EXPECT_THAT
(
trained_model
.
GraphDef
(
&
graph_def
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
// Tests that TrainedModel::EvaluateTensor() fails if there is no model.
TEST
(
TrainedModelTest
,
EvaluateTensorFailsBeforeReset
)
{
TrainedModel
trained_model
;
tensorflow
::
Tensor
tensor
;
EXPECT_THAT
(
trained_model
.
EvaluateTensor
(
"whatever"
,
&
tensor
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
// Tests that TrainedModel::LookupNode() fails if there is no model.
TEST
(
TrainedModelTest
,
LookupNodeFailsBeforeReset
)
{
TrainedModel
trained_model
;
const
tensorflow
::
NodeDef
*
node
=
nullptr
;
EXPECT_THAT
(
trained_model
.
LookupNode
(
"whatever"
,
&
node
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/trained_model_variable_store.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model_variable_store.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
tensorflow
::
Status
TrainedModelVariableStore
::
Reset
(
const
string
&
saved_model_dir
)
{
TF_RETURN_IF_ERROR
(
trained_model_
.
Reset
(
saved_model_dir
));
// Success; make modifications.
variables_
.
clear
();
return
tensorflow
::
Status
::
OK
();
}
namespace
{
// Copies flat data from the |tensor|, casted to T, into the |array| and points
// the |area| at it. On error, returns non-OK.
template
<
class
T
>
tensorflow
::
Status
ExtractFlat
(
const
tensorflow
::
Tensor
&
tensor
,
std
::
vector
<
size_t
>
*
dimensions
,
UniqueAlignedArray
*
array
,
MutableAlignedArea
*
area
)
{
const
auto
flat
=
tensor
.
flat
<
T
>
();
const
size_t
bytes
=
flat
.
size
()
*
sizeof
(
T
);
array
->
Reset
(
ComputeAlignedAreaSize
(
1
,
bytes
));
TF_RETURN_IF_ERROR
(
area
->
Reset
(
array
->
view
(),
1
,
bytes
));
const
MutableVector
<
T
>
row
(
area
->
view
(
0
));
for
(
size_t
i
=
0
;
i
<
flat
.
size
();
++
i
)
row
[
i
]
=
flat
(
i
);
dimensions
->
clear
();
dimensions
->
push_back
(
flat
.
size
());
return
tensorflow
::
Status
::
OK
();
}
// Copies the |tensor|, casted to T and reshaped as a matrix, into the |array|
// and points the |area| at it. Requires that the |tensor| is rank 2 or more.
// On error, returns non-OK.
template
<
class
T
>
tensorflow
::
Status
ExtractMatrix
(
const
tensorflow
::
Tensor
&
tensor
,
std
::
vector
<
size_t
>
*
dimensions
,
UniqueAlignedArray
*
array
,
MutableAlignedArea
*
area
)
{
if
(
tensor
.
dims
()
<
2
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Tensor must be rank >= 2 but is rank "
,
tensor
.
dims
());
}
// Flatten all dims except the inner-most, creating a matrix.
const
auto
reshaped
=
tensor
.
flat_inner_dims
<
T
>
();
const
size_t
num_rows
=
reshaped
.
dimension
(
0
);
const
size_t
num_columns
=
reshaped
.
dimension
(
1
);
*
dimensions
=
{
num_rows
,
num_columns
};
const
size_t
view_size_bytes
=
num_columns
*
sizeof
(
T
);
array
->
Reset
(
ComputeAlignedAreaSize
(
num_rows
,
view_size_bytes
));
TF_RETURN_IF_ERROR
(
area
->
Reset
(
array
->
view
(),
num_rows
,
view_size_bytes
));
MutableMatrix
<
T
>
matrix
(
*
area
);
for
(
size_t
row
=
0
;
row
<
num_rows
;
++
row
)
{
for
(
size_t
column
=
0
;
column
<
num_columns
;
++
column
)
{
matrix
.
row
(
row
)[
column
]
=
reshaped
(
row
,
column
);
}
}
return
tensorflow
::
Status
::
OK
();
}
// Copies a blocked matrix from the |tensor|, casted to T, into the |array| and
// points the |area| at it. Requires that the |tensor| is rank 3. On error,
// returns non-OK.
template
<
class
T
>
tensorflow
::
Status
ExtractBlockedMatrix
(
const
tensorflow
::
Tensor
&
tensor
,
std
::
vector
<
size_t
>
*
dimensions
,
UniqueAlignedArray
*
array
,
MutableAlignedArea
*
area
)
{
if
(
tensor
.
dims
()
!=
3
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Tensor must be rank 3 but is rank "
,
tensor
.
dims
());
}
const
size_t
num_sub_matrices
=
tensor
.
dim_size
(
0
);
const
size_t
num_rows
=
tensor
.
dim_size
(
1
);
const
size_t
block_size
=
tensor
.
dim_size
(
2
);
const
size_t
num_columns
=
num_sub_matrices
*
block_size
;
*
dimensions
=
{
num_rows
,
num_columns
,
block_size
};
// Given the order of dimensions in the |tensor|, flattening it into a matrix
// via flat_inner_dims() and copying it to the |area| is equivalent to copying
// the blocked matrix.
std
::
vector
<
size_t
>
unused_dimensions
;
// ignore non-blocked dimensions
return
ExtractMatrix
<
T
>
(
tensor
,
&
unused_dimensions
,
array
,
area
);
}
}
// namespace
tensorflow
::
Status
TrainedModelVariableStore
::
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
{
const
Key
key
(
name
,
format
);
const
auto
it
=
variables_
.
find
(
key
);
if
(
it
!=
variables_
.
end
())
{
std
::
tie
(
std
::
ignore
,
*
dimensions
,
*
area
)
=
it
->
second
;
return
tensorflow
::
Status
::
OK
();
}
Variable
variable
;
TF_RETURN_IF_ERROR
(
GetVariableContents
(
name
,
format
,
&
variable
));
// Success; make modifications.
std
::
tie
(
std
::
ignore
,
*
dimensions
,
*
area
)
=
variable
;
variables_
[
key
]
=
std
::
move
(
variable
);
return
tensorflow
::
Status
::
OK
();
}
tensorflow
::
Status
TrainedModelVariableStore
::
GetVariableContents
(
const
string
&
name
,
VariableSpec
::
Format
format
,
Variable
*
variable
)
{
tensorflow
::
Tensor
tensor
;
TF_RETURN_IF_ERROR
(
trained_model_
.
EvaluateTensor
(
name
,
&
tensor
));
// Extract typed tensor data.
UniqueAlignedArray
*
array
=
&
std
::
get
<
0
>
(
*
variable
);
std
::
vector
<
size_t
>
*
dimensions
=
&
std
::
get
<
1
>
(
*
variable
);
MutableAlignedArea
*
area
=
&
std
::
get
<
2
>
(
*
variable
);
if
(
tensor
.
dtype
()
==
tensorflow
::
DT_FLOAT
)
{
switch
(
format
)
{
case
VariableSpec
::
FORMAT_UNKNOWN
:
return
tensorflow
::
errors
::
InvalidArgument
(
"Unknown variable format"
);
case
VariableSpec
::
FORMAT_FLAT
:
return
ExtractFlat
<
float
>
(
tensor
,
dimensions
,
array
,
area
);
case
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
:
return
ExtractMatrix
<
float
>
(
tensor
,
dimensions
,
array
,
area
);
case
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
:
return
ExtractBlockedMatrix
<
float
>
(
tensor
,
dimensions
,
array
,
area
);
}
}
else
if
(
tensor
.
dtype
()
==
tensorflow
::
DT_BFLOAT16
)
{
switch
(
format
)
{
case
VariableSpec
::
FORMAT_UNKNOWN
:
return
tensorflow
::
errors
::
InvalidArgument
(
"Unknown variable format"
);
case
VariableSpec
::
FORMAT_FLAT
:
return
ExtractFlat
<
tensorflow
::
bfloat16
>
(
tensor
,
dimensions
,
array
,
area
);
case
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
:
return
ExtractMatrix
<
tensorflow
::
bfloat16
>
(
tensor
,
dimensions
,
array
,
area
);
case
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
:
return
ExtractBlockedMatrix
<
tensorflow
::
bfloat16
>
(
tensor
,
dimensions
,
array
,
area
);
}
}
else
{
// TODO(googleuser): Add clauses for additional types as needed.
return
tensorflow
::
errors
::
Unimplemented
(
"Data type not supported: "
,
tensorflow
::
DataType_Name
(
tensor
.
dtype
()));
}
}
tensorflow
::
Status
TrainedModelVariableStore
::
Close
()
{
return
trained_model_
.
Close
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/trained_model_variable_store.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A variable store that extracts variables from a trained DRAGNN model. This
// should not be used in production (where ArrayVariableStore and its subclasses
// should be used), though it is convenient for experimentation.
class
TrainedModelVariableStore
:
public
VariableStore
{
public:
// Creates an uninitialized store.
TrainedModelVariableStore
()
=
default
;
// Resets this to represent the variables defined by the TF saved model at the
// |saved_model_dir|. On error, returns non-OK and modifies nothing.
tensorflow
::
Status
Reset
(
const
string
&
saved_model_dir
);
// Implements VariableStore.
using
VariableStore
::
Lookup
;
// import Lookup<T>() convenience methods
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
override
;
tensorflow
::
Status
Close
()
override
;
private:
// A (name,format) key associated with a variable.
using
Key
=
std
::
pair
<
string
,
VariableSpec
::
Format
>
;
// Extracted and formatted variable contents, as an aligned byte array and an
// area that provides a structured interpretation.
using
Variable
=
std
::
tuple
<
UniqueAlignedArray
,
std
::
vector
<
size_t
>
,
MutableAlignedArea
>
;
// Extracts the contents of the variable named |name| in the |format| and
// stores the result in the |variable|. On error, returns non-OK.
tensorflow
::
Status
GetVariableContents
(
const
string
&
name
,
VariableSpec
::
Format
format
,
Variable
*
variable
);
// Trained DRAGNN model used to extract variables.
TrainedModel
trained_model_
;
// The already-extracted variables.
std
::
map
<
Key
,
Variable
>
variables_
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
research/syntaxnet/dragnn/runtime/trained_model_variable_store_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model_variable_store.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/math/avx_vector_array.h"
#include "dragnn/runtime/math/float16_types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
class
TrainedModelVariableStoreTest
:
public
::
testing
::
Test
{
protected:
// Computes a value that accesses all bytes in the |view| or |area|. Useful
// for checking that a piece of memory is accessible.
size_t
SumBytes
(
AlignedView
view
)
{
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
view
.
size
();
++
i
)
sum
+=
view
.
data
()[
i
];
return
sum
;
}
size_t
SumBytes
(
AlignedArea
area
)
{
size_t
sum
=
0
;
for
(
size_t
i
=
0
;
i
<
area
.
num_views
();
++
i
)
sum
+=
SumBytes
(
area
.
view
(
i
));
return
sum
;
}
// Returns the name of a tensor containing the blocked version of
// |kVariableName|, with the given |block_size|.
string
GetBlockedVariableName
(
int
block_size
)
const
{
return
tensorflow
::
strings
::
StrCat
(
kVariableNamePrefix
,
"/matrix/blocked"
,
block_size
,
"/ExponentialMovingAverage"
);
}
// Same as above, but returns the name of the bfloat16 variable.
string
GetBfloat16VariableName
(
int
block_size
)
const
{
return
tensorflow
::
strings
::
StrCat
(
kVariableNamePrefix
,
"/matrix/blocked"
,
block_size
,
"/bfloat16/ExponentialMovingAverage"
);
}
// Path to a saved model file for tests. Expected to contain:
// * A tf.float32 variable named |kVariableName| with shape
// [|kVariableRows|, |kVariableColumns|].
// * A variable named |kUnsupportedTypeVariableName| whose type is not
// supported by the implementation.
// * A variable named |kLowRankVariableName| whose rank is < 2.
const
string
kSavedModelDir
=
tensorflow
::
io
::
JoinPath
(
test
::
GetTestDataPrefix
(),
"dragnn/runtime/testdata/rnn_tagger"
);
// A valid variable name in the test model and its dimensions.
const
string
kVariableNamePrefix
=
"tagger/weights_0"
;
const
string
kVariableName
=
tensorflow
::
strings
::
StrCat
(
kVariableNamePrefix
,
"/ExponentialMovingAverage"
);
const
size_t
kVariableRows
=
160
;
const
size_t
kVariableColumns
=
64
;
// A variable with unsupported type; this variable is tf.int32.
const
string
kUnsupportedTypeVariableName
=
"tagger/step"
;
// A variable whose rank is < 2; this is a scalar.
const
string
kLowRankVariableName
=
"tagger/bias_1"
;
// Variable store for tests.
TrainedModelVariableStore
store_
;
};
// Tests that TrainedModelVariableStore can be initialized from a valid model.
TEST_F
(
TrainedModelVariableStoreTest
,
ResetValid
)
{
TF_EXPECT_OK
(
store_
.
Reset
(
kSavedModelDir
));
}
// Tests that TrainedModelVariableStore fails on a valid directory that doesn't
// actually contain a TF saved model, but can be re-Reset() on valid files.
TEST_F
(
TrainedModelVariableStoreTest
,
ResetInvalidDirectoryThenValid
)
{
EXPECT_FALSE
(
store_
.
Reset
(
"/tmp"
).
ok
());
TF_EXPECT_OK
(
store_
.
Reset
(
kSavedModelDir
));
}
// Tests that TrainedModelVariableStore fails on a non-directory, but can be
// re-Reset() on valid files.
TEST_F
(
TrainedModelVariableStoreTest
,
ResetNotADirectoryThenValid
)
{
EXPECT_FALSE
(
store_
.
Reset
(
"/dev/null"
).
ok
());
TF_EXPECT_OK
(
store_
.
Reset
(
kSavedModelDir
));
}
// Tests that TrainedModelVariableStore fails with missing files node scope, but
// can be re-Reset() on valid files.
TEST_F
(
TrainedModelVariableStoreTest
,
ResetMissingDirectoryThenValid
)
{
EXPECT_FALSE
(
store_
.
Reset
(
"/missing/model/dir"
).
ok
());
TF_EXPECT_OK
(
store_
.
Reset
(
kSavedModelDir
));
}
// Tests that TrainedModelVariableStore can only be closed once, and only after
// it is has been initialized.
TEST_F
(
TrainedModelVariableStoreTest
,
Close
)
{
EXPECT_THAT
(
store_
.
Close
(),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
TF_EXPECT_OK
(
store_
.
Close
());
EXPECT_THAT
(
store_
.
Close
(),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
}
// Tests that TrainedModelVariableStore can look up flat variables.
TEST_F
(
TrainedModelVariableStoreTest
,
LookupFlat
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
// Fail to look up a valid name before initialization.
EXPECT_THAT
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Repeating the failed lookup should still fail.
EXPECT_THAT
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Fail to look up an invalid name after initialization.
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
EXPECT_FALSE
(
store_
.
Lookup
(
"invalid/name"
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
)
.
ok
());
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Successfully look up a valid name.
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
));
EXPECT_FALSE
(
area
.
empty
());
// modified
EXPECT_EQ
(
area
.
num_views
(),
1
);
EXPECT_EQ
(
area
.
view_size
(),
kVariableRows
*
kVariableColumns
*
sizeof
(
float
));
// Try looking up the same name again.
area
=
AlignedArea
();
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
));
EXPECT_EQ
(
area
.
num_views
(),
1
);
EXPECT_EQ
(
area
.
view_size
(),
kVariableRows
*
kVariableColumns
*
sizeof
(
float
));
// Check that the area can be accessed even after the |store| is closed.
TF_EXPECT_OK
(
store_
.
Close
());
LOG
(
INFO
)
<<
"Logging to prevent elision by optimizer: "
<<
SumBytes
(
area
);
}
// Tests that TrainedModelVariableStore can look up row-major matrix variables.
TEST_F
(
TrainedModelVariableStoreTest
,
LookupRowMajorMatrix
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
// Fail to look up a valid name before initialization.
EXPECT_THAT
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Repeating the failed lookup should still fail.
EXPECT_THAT
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"TF Session is not active"
));
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Fail to look up an invalid name after initialization.
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
EXPECT_FALSE
(
store_
.
Lookup
(
"invalid/name"
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
)
.
ok
());
EXPECT_TRUE
(
area
.
empty
());
// not modified
// Successfully look up a valid name.
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
));
ASSERT_FALSE
(
area
.
empty
());
// modified
EXPECT_EQ
(
dimensions
,
std
::
vector
<
size_t
>
({
kVariableRows
,
kVariableColumns
}));
EXPECT_EQ
(
area
.
num_views
(),
kVariableRows
);
EXPECT_EQ
(
area
.
view_size
(),
kVariableColumns
*
sizeof
(
float
));
// Try looking up the same name again.
area
=
AlignedArea
();
dimensions
.
clear
();
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
));
EXPECT_EQ
(
dimensions
,
std
::
vector
<
size_t
>
({
kVariableRows
,
kVariableColumns
}));
EXPECT_EQ
(
area
.
num_views
(),
kVariableRows
);
EXPECT_EQ
(
area
.
view_size
(),
kVariableColumns
*
sizeof
(
float
));
// Check that the area can be accessed even after the |store| is closed.
TF_EXPECT_OK
(
store_
.
Close
());
LOG
(
INFO
)
<<
"Logging to prevent elision by optimizer: "
<<
SumBytes
(
area
);
}
// Tests that the same contents can be retrieved in various formats, and that
// the content is the same asides from rearrangement.
TEST_F
(
TrainedModelVariableStoreTest
,
CompareFormats
)
{
Vector
<
float
>
flat
;
Matrix
<
float
>
row_major_matrix
;
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
&
flat
));
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
&
row_major_matrix
));
ASSERT_EQ
(
flat
.
size
(),
row_major_matrix
.
num_rows
()
*
row_major_matrix
.
num_columns
());
for
(
size_t
flat_index
=
0
,
row
=
0
;
row
<
row_major_matrix
.
num_rows
();
++
row
)
{
for
(
size_t
column
=
0
;
column
<
row_major_matrix
.
num_columns
();
++
column
,
++
flat_index
)
{
EXPECT_EQ
(
row_major_matrix
.
row
(
row
)[
column
],
flat
[
flat_index
]);
}
}
}
// Tests that TrainedModelVariableStore fails to retrieve a variable of an
// unsupported type.
TEST_F
(
TrainedModelVariableStoreTest
,
LookupUnsupportedType
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
EXPECT_THAT
(
store_
.
Lookup
(
kUnsupportedTypeVariableName
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"Data type not supported"
));
}
// Tests that TrainedModelVariableStore fails to retrieve a variable of an
// unsupported type.
TEST_F
(
TrainedModelVariableStoreTest
,
LookupUnknownFormat
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
EXPECT_THAT
(
store_
.
Lookup
(
kVariableName
,
VariableSpec
::
FORMAT_UNKNOWN
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"Unknown variable format"
));
}
// Tests that TrainedModelVariableStore fails to look up a variable without
// sufficient structure as an matrix.
TEST_F
(
TrainedModelVariableStoreTest
,
LookupInsufficientRank
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
EXPECT_THAT
(
store_
.
Lookup
(
kLowRankVariableName
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
),
test
::
IsErrorWithSubstr
(
"Tensor must be rank >= 2"
));
}
// Tests that TrainedModelVariableStore produces column-blocked row-major
// matrices with the same content as the non-blocked version. Checks that
// bfloat16 matrices are a permuted version of blocked matrices.
TEST_F
(
TrainedModelVariableStoreTest
,
ColumnBlockedComparison
)
{
const
int
kBlockSize
=
32
;
const
string
kBlockedVariableName
=
GetBlockedVariableName
(
kBlockSize
);
const
string
kBfloat16VariableName
=
GetBfloat16VariableName
(
kBlockSize
);
Matrix
<
float
>
plain_matrix
;
BlockedMatrix
<
float
>
matrix
;
BlockedMatrix
<
TruncatedFloat16
>
bfloat16_matrix
;
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
TF_ASSERT_OK
(
store_
.
Lookup
(
kVariableName
,
&
plain_matrix
));
TF_ASSERT_OK
(
store_
.
Lookup
(
kBlockedVariableName
,
&
matrix
));
TF_ASSERT_OK
(
store_
.
Lookup
(
kBfloat16VariableName
,
&
bfloat16_matrix
));
ASSERT_EQ
(
matrix
.
num_rows
(),
kVariableRows
);
ASSERT_EQ
(
matrix
.
num_columns
(),
kVariableColumns
);
ASSERT_EQ
(
matrix
.
block_size
(),
kBlockSize
);
// Compare the content of the plain matrix with the blocked version.
for
(
int
column
=
0
;
column
<
matrix
.
num_columns
();
++
column
)
{
const
int
column_block_index
=
column
/
kBlockSize
;
const
int
index_in_block
=
column
%
kBlockSize
;
for
(
int
row
=
0
;
row
<
matrix
.
num_rows
();
++
row
)
{
const
int
block_index
=
column_block_index
*
matrix
.
num_rows
()
+
row
;
Vector
<
float
>
block
=
matrix
.
vector
(
block_index
);
EXPECT_EQ
(
block
[
index_in_block
],
plain_matrix
.
row
(
row
)[
column
]);
}
}
// Compare bfloat16-encoded values with float32 values.
ASSERT_EQ
(
matrix
.
num_vectors
(),
bfloat16_matrix
.
num_vectors
());
ASSERT_EQ
(
matrix
.
block_size
(),
bfloat16_matrix
.
block_size
());
ASSERT_EQ
(
matrix
.
num_rows
(),
bfloat16_matrix
.
num_rows
());
ASSERT_EQ
(
matrix
.
num_columns
(),
bfloat16_matrix
.
num_columns
());
for
(
int
vector
=
0
;
vector
<
matrix
.
num_vectors
();
++
vector
)
{
const
auto
&
matrix_vector
=
matrix
.
vector
(
vector
);
const
auto
&
bfloat16_vector
=
bfloat16_matrix
.
vector
(
vector
);
for
(
int
i
=
0
;
i
<
matrix
.
block_size
();
++
i
)
{
int
permuted
=
FastUnpackPermutation
(
i
);
const
float
matrix_value
=
matrix_vector
[
i
];
const
float
bfloat16_value
=
bfloat16_vector
[
permuted
].
DebugToFloat
();
EXPECT_NEAR
(
matrix_value
,
bfloat16_value
,
5e-3
);
}
}
}
// Tests that TrainedModelVariableStore overwrites the dimension vector passed
// to Lookup().
TEST_F
(
TrainedModelVariableStoreTest
,
OverwritesDimensions
)
{
const
int
kBlockSize
=
32
;
const
string
kBlockedVariableName
=
GetBlockedVariableName
(
kBlockSize
);
TF_ASSERT_OK
(
store_
.
Reset
(
kSavedModelDir
));
std
::
vector
<
VariableSpec
::
Format
>
formats
{
VariableSpec
::
FORMAT_FLAT
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
};
for
(
const
auto
&
format
:
formats
)
{
std
::
vector
<
size_t
>
dimensions
;
dimensions
.
push_back
(
1234
);
AlignedArea
area
;
TF_ASSERT_OK
(
store_
.
Lookup
(
kBlockedVariableName
,
format
,
&
dimensions
,
&
area
));
EXPECT_NE
(
dimensions
[
0
],
1234
);
std
::
vector
<
size_t
>
expected_dimensions
;
switch
(
format
)
{
case
VariableSpec
::
FORMAT_UNKNOWN
:
LOG
(
FATAL
)
<<
"Invalid format"
;
case
VariableSpec
::
FORMAT_FLAT
:
expected_dimensions
=
{
kVariableRows
*
kVariableColumns
};
break
;
case
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
:
// NB: We're fetching the rank-3 "/matrix/blockedNN" version and then
// reshaping into a matrix, so the dimensions are not the same as the
// plain matrix.
expected_dimensions
=
{
kVariableRows
*
kVariableColumns
/
kBlockSize
,
kBlockSize
};
break
;
case
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
:
expected_dimensions
=
{
kVariableRows
,
kVariableColumns
,
kBlockSize
};
break
;
}
EXPECT_EQ
(
dimensions
,
expected_dimensions
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/transition_system_traits.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/transition_system_traits.h"
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Note: The traits are currently simple enough to specify in one file. We can
// also use a registry-based system if this gets too complex.
// Returns true if the |component_spec| is deterministic.
bool
IsDeterministic
(
const
ComponentSpec
&
component_spec
)
{
return
component_spec
.
num_actions
()
==
1
;
}
// Returns true if the |component_spec| is sequential.
bool
IsSequential
(
const
ComponentSpec
&
component_spec
)
{
const
string
&
name
=
component_spec
.
transition_system
().
registered_name
();
return
name
==
"char-shift-only"
||
//
name
==
"shift-only"
||
//
name
==
"tagger"
||
//
name
==
"morpher"
||
//
name
==
"heads"
||
//
name
==
"labels"
;
}
// Returns true if the |component_spec| specifies a left-to-right transition
// system. The default when unspecified is true.
bool
IsLeftToRight
(
const
ComponentSpec
&
component_spec
)
{
const
auto
&
parameters
=
component_spec
.
transition_system
().
parameters
();
const
auto
it
=
parameters
.
find
(
"left_to_right"
);
if
(
it
==
parameters
.
end
())
return
true
;
return
tensorflow
::
str_util
::
Lowercase
(
it
->
second
)
!=
"false"
;
}
// Returns true if the |transition_system| is character-scale.
bool
IsCharacterScale
(
const
ComponentSpec
&
component_spec
)
{
const
string
&
name
=
component_spec
.
transition_system
().
registered_name
();
return
//
name
==
"char-shift-only"
;
}
// Returns true if the |transition_system| is token-scale.
bool
IsTokenScale
(
const
ComponentSpec
&
component_spec
)
{
const
string
&
name
=
component_spec
.
transition_system
().
registered_name
();
return
name
==
"shift-only"
||
//
name
==
"tagger"
||
//
name
==
"morpher"
||
//
name
==
"heads"
||
//
name
==
"labels"
;
}
}
// namespace
TransitionSystemTraits
::
TransitionSystemTraits
(
const
ComponentSpec
&
component_spec
)
:
is_deterministic
(
IsDeterministic
(
component_spec
)),
is_sequential
(
IsSequential
(
component_spec
)),
is_left_to_right
(
IsLeftToRight
(
component_spec
)),
is_character_scale
(
IsCharacterScale
(
component_spec
)),
is_token_scale
(
IsTokenScale
(
component_spec
))
{}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/transition_system_traits.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
#define DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
#include "dragnn/protos/spec.pb.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Traits describing the transition system used by some component.
struct
TransitionSystemTraits
{
// Creates a set of traits describing the |component_spec|.
explicit
TransitionSystemTraits
(
const
ComponentSpec
&
component_spec
);
// Whether the transition system is deterministic---i.e., it can be advanced
// without computing logits and making predictions.
const
bool
is_deterministic
;
// Whether the transition system is sequential---i.e., compatible with
// SequenceBackend, SequenceExtractor, and so on.
const
bool
is_sequential
;
// Whether the transition system advances from left to right in the underlying
// input sequence. This only makes sense if |sequential| is true.
const
bool
is_left_to_right
;
// Whether the transition steps correspond to characters or tokens. This only
// makes sense if |sequential| is true.
//
// TODO(googleuser): Distinguish between full-text character transition systems
// and per-word ones?
const
bool
is_character_scale
;
const
bool
is_token_scale
;
};
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
research/syntaxnet/dragnn/runtime/transition_system_traits_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/transition_system_traits.h"
#include <string>
#include <utility>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns a ComponentSpec that uses the |transition_system|, is configured to
// run left-to-right if |left_to_right| is true, and whose transition system
// predicts |num_actions| actions.
ComponentSpec
MakeTestSpec
(
const
string
&
transition_system
,
bool
left_to_right
,
int
num_actions
)
{
ComponentSpec
component_spec
;
component_spec
.
set_num_actions
(
num_actions
);
component_spec
.
mutable_transition_system
()
->
set_registered_name
(
transition_system
);
component_spec
.
mutable_transition_system
()
->
mutable_parameters
()
->
insert
(
{
"left_to_right"
,
left_to_right
?
"true"
:
"false"
});
return
component_spec
;
}
// Tests that boolean values are case-insensitive.
TEST
(
TransitionSystemTraitsAttributeParsingTest
,
CaseInsensitiveBooleanValues
)
{
ComponentSpec
component_spec
=
MakeTestSpec
(
"shift-only"
,
false
,
1
);
auto
&
parameters
=
*
component_spec
.
mutable_transition_system
()
->
mutable_parameters
();
for
(
const
string
&
true_value
:
{
"TRUE"
,
"True"
})
{
parameters
[
"left_to_right"
]
=
true_value
;
TransitionSystemTraits
traits
(
component_spec
);
EXPECT_TRUE
(
traits
.
is_left_to_right
);
}
for
(
const
string
&
false_value
:
{
"FALSE"
,
"False"
})
{
parameters
[
"left_to_right"
]
=
false_value
;
TransitionSystemTraits
traits
(
component_spec
);
EXPECT_FALSE
(
traits
.
is_left_to_right
);
}
}
// Parameterized on (left-to-right, deterministic).
class
TransitionSystemTraitsTest
:
public
::
testing
::
TestWithParam
<::
testing
::
tuple
<
bool
,
bool
>>
{
protected:
// Returns the test parameters.
bool
left_to_right
()
const
{
return
::
testing
::
get
<
0
>
(
GetParam
());
}
bool
deterministic
()
const
{
return
::
testing
::
get
<
1
>
(
GetParam
());
}
// Returns a ComponentSpec for the |transition_system|.
ComponentSpec
MakeSpec
(
const
string
&
transition_system
)
{
return
MakeTestSpec
(
transition_system
,
left_to_right
(),
deterministic
()
?
1
:
10
);
}
};
INSTANTIATE_TEST_CASE_P
(
LeftToRightAndDeterministic
,
TransitionSystemTraitsTest
,
::
testing
::
Combine
(
::
testing
::
Bool
(),
::
testing
::
Bool
()));
// Tests the traits of an unknown transition system.
TEST_P
(
TransitionSystemTraitsTest
,
Unknown
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"unknown"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_FALSE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_FALSE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "char-shift-only" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
CharShiftOnly
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"char-shift-only"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_TRUE
(
traits
.
is_character_scale
);
EXPECT_FALSE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "shift-only" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
ShiftOnly
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"shift-only"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_TRUE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "tagger" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
Tagger
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"tagger"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_TRUE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "morpher" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
Morpher
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"morpher"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_TRUE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "heads" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
Heads
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"heads"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_TRUE
(
traits
.
is_token_scale
);
}
// Tests the traits of the "labels" transition system.
TEST_P
(
TransitionSystemTraitsTest
,
Labels
)
{
TransitionSystemTraits
traits
(
MakeSpec
(
"labels"
));
EXPECT_EQ
(
traits
.
is_deterministic
,
deterministic
());
EXPECT_TRUE
(
traits
.
is_sequential
);
EXPECT_EQ
(
traits
.
is_left_to_right
,
left_to_right
());
EXPECT_FALSE
(
traits
.
is_character_scale
);
EXPECT_TRUE
(
traits
.
is_token_scale
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/type_keyed_set.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
#define DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
#include <map>
#include <utility>
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A heterogeneous set of type-keyed objects. Objects of any type can be added,
// but this can only hold at most one object of each type.
//
// Note that this class does not have any locking, so threads must externally
// coordinate to ensure that every instance of this set is only accessed by one
// thread at a time. When used via SessionState, these conditions are enforced
// by the runtime framework.
class
TypeKeyedSet
{
public:
// Creates an empty set.
TypeKeyedSet
()
=
default
;
// Moves all objects from |that| to this. Afterwards, the objects in this are
// address-equal to the objects originally in |that|.
TypeKeyedSet
(
TypeKeyedSet
&&
that
);
TypeKeyedSet
&
operator
=
(
TypeKeyedSet
&&
that
);
~
TypeKeyedSet
()
{
Clear
();
}
// Removes all objects from this set.
void
Clear
();
// Returns the T in this set, creating it first via T() if needed.
template
<
class
T
>
T
&
Get
();
private:
// Function that can delete an untyped pointer using the proper type.
using
Deleter
=
void
(
*
)(
void
*
);
// Deletes the |object| as a T. All Deleters point to this function.
template
<
class
T
>
static
void
DeleteAs
(
void
*
object
);
// Mapping from deleter to object. This owns the objects.
std
::
map
<
Deleter
,
void
*>
objects_
;
};
// Implementation details below.
inline
TypeKeyedSet
::
TypeKeyedSet
(
TypeKeyedSet
&&
that
)
:
objects_
(
std
::
move
(
that
.
objects_
))
{
that
.
objects_
.
clear
();
}
inline
TypeKeyedSet
&
TypeKeyedSet
::
operator
=
(
TypeKeyedSet
&&
that
)
{
Clear
();
objects_
=
std
::
move
(
that
.
objects_
);
that
.
objects_
.
clear
();
return
*
this
;
}
inline
void
TypeKeyedSet
::
Clear
()
{
for
(
const
auto
&
it
:
objects_
)
it
.
first
(
it
.
second
);
objects_
.
clear
();
}
template
<
class
T
>
T
&
TypeKeyedSet
::
Get
()
{
// Implementation notes:
// * DeleteAs<T>() is unique per T, so keying on its instantiation it is
// equivalent to keying on type, as desired.
// * The |object| pointer below is doubly-indirect: it is a reference to a
// void* pointer that lives in the |objects_| map.
// * If there was previously no entry in |objects_|, then |object| will be
// value-initialized (i.e., nulled), and we reassign it to a new T().
void
*&
object
=
objects_
[
&
DeleteAs
<
T
>
];
if
(
object
==
nullptr
)
object
=
new
T
();
return
*
reinterpret_cast
<
T
*>
(
object
);
}
template
<
class
T
>
void
TypeKeyedSet
::
DeleteAs
(
void
*
object
)
{
delete
reinterpret_cast
<
T
*>
(
object
);
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
research/syntaxnet/dragnn/runtime/type_keyed_set_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/type_keyed_set.h"
#include <utility>
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Dummy struct for tests.
struct
Foo
{
float
value
=
-
1.5
;
};
// Type aliases to exercise usage of aliases as type keys.
using
OtherInt
=
int
;
using
OtherFoo
=
Foo
;
// Tests that TypeKeyedSet::Get() returns the same object once created.
TEST
(
TypeKeyedSetTest
,
Get
)
{
TypeKeyedSet
set
;
// Get a couple types, and check for default-constructed values.
int
&
int_object
=
set
.
Get
<
int
>
();
ASSERT_NE
(
&
int_object
,
nullptr
);
EXPECT_EQ
(
int_object
,
0
);
// due to T()
int_object
=
2718
;
Foo
&
foo_object
=
set
.
Get
<
Foo
>
();
ASSERT_NE
(
&
foo_object
,
nullptr
);
EXPECT_EQ
(
foo_object
.
value
,
-
1.5
);
// due to T()
foo_object
.
value
=
3141.5
;
// Get the same types again, this time using type aliases, and check for
// address and value equality.
OtherInt
&
other_int_object
=
set
.
Get
<
OtherInt
>
();
EXPECT_EQ
(
&
other_int_object
,
&
int_object
);
EXPECT_EQ
(
other_int_object
,
2718
);
OtherFoo
&
other_foo_object
=
set
.
Get
<
OtherFoo
>
();
EXPECT_EQ
(
&
other_foo_object
,
&
foo_object
);
EXPECT_EQ
(
other_foo_object
.
value
,
3141.5
);
}
// Tests that TypeKeyedSet::Clear() removes existing values.
TEST
(
TypeKeyedSetTest
,
Clear
)
{
// Create a set with some values.
TypeKeyedSet
set
;
int
&
int_object
=
set
.
Get
<
int
>
();
int_object
=
2718
;
Foo
&
foo_object
=
set
.
Get
<
Foo
>
();
foo_object
.
value
=
3141.5
;
// Clear the set and check that the values are now defaulted.
set
.
Clear
();
EXPECT_EQ
(
set
.
Get
<
int
>
(),
0
);
EXPECT_EQ
(
set
.
Get
<
Foo
>
().
value
,
-
1.5
);
}
// Tests that TypeKeyedSet supports move construction.
TEST
(
TypeKeyedSetTest
,
MoveConstruction
)
{
TypeKeyedSet
set1
;
// Insert a couple of values.
int
&
int_object
=
set1
.
Get
<
int
>
();
int_object
=
2718
;
Foo
&
foo_object
=
set1
.
Get
<
Foo
>
();
foo_object
.
value
=
3141.5
;
// Move-construct another set, and check address and value equality.
TypeKeyedSet
set2
(
std
::
move
(
set1
));
OtherInt
&
other_int_object
=
set2
.
Get
<
OtherInt
>
();
EXPECT_EQ
(
&
other_int_object
,
&
int_object
);
EXPECT_EQ
(
other_int_object
,
2718
);
OtherFoo
&
other_foo_object
=
set2
.
Get
<
OtherFoo
>
();
EXPECT_EQ
(
&
other_foo_object
,
&
foo_object
);
EXPECT_EQ
(
other_foo_object
.
value
,
3141.5
);
}
// Tests that TypeKeyedSet supports move assignment.
TEST
(
TypeKeyedSetTest
,
MoveAssignment
)
{
// Create one set with some values.
TypeKeyedSet
set1
;
int
&
int_object
=
set1
.
Get
<
int
>
();
int_object
=
2718
;
Foo
&
foo_object
=
set1
.
Get
<
Foo
>
();
foo_object
.
value
=
3141.5
;
// Create another set with different values, to be overwritten.
TypeKeyedSet
set2
;
set2
.
Get
<
int
>
()
=
123
;
set2
.
Get
<
Foo
>
().
value
=
76.5
;
// Move-assign to another set, and check address and value equality.
set2
=
std
::
move
(
set1
);
OtherInt
&
other_int_object
=
set2
.
Get
<
OtherInt
>
();
EXPECT_EQ
(
&
other_int_object
,
&
int_object
);
EXPECT_EQ
(
other_int_object
,
2718
);
OtherFoo
&
other_foo_object
=
set2
.
Get
<
OtherFoo
>
();
EXPECT_EQ
(
&
other_foo_object
,
&
foo_object
);
EXPECT_EQ
(
other_foo_object
.
value
,
3141.5
);
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/unicode_dictionary.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/unicode_dictionary.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
// Returns a string representation of the byte sequence of the |character|.
string
CharacterDebugString
(
const
string
&
character
)
{
const
auto
*
bytes
=
reinterpret_cast
<
const
uint8
*>
(
character
.
data
());
string
debug
=
"["
;
for
(
int
i
=
0
;
i
<
character
.
size
();
++
i
)
{
tensorflow
::
strings
::
StrAppend
(
&
debug
,
i
==
0
?
""
:
" "
,
bytes
[
i
]);
}
tensorflow
::
strings
::
StrAppend
(
&
debug
,
"]"
);
return
debug
;
}
}
// namespace
UnicodeDictionary
::
UnicodeDictionary
()
{
Clear
();
}
UnicodeDictionary
::
UnicodeDictionary
(
const
string
&
character_map_path
,
int
min_frequency
,
int
max_num_terms
)
{
TF_CHECK_OK
(
Reset
(
TermFrequencyMap
(
character_map_path
,
min_frequency
,
max_num_terms
)));
}
void
UnicodeDictionary
::
Clear
()
{
size_
=
0
;
for
(
int32
&
index
:
single_byte_indices_
)
index
=
-
1
;
multi_byte_indices_
.
clear
();
}
tensorflow
::
Status
UnicodeDictionary
::
Reset
(
const
TermFrequencyMap
&
character_map
)
{
Clear
();
size_
=
character_map
.
Size
();
for
(
int32
index
=
0
;
index
<
character_map
.
Size
();
++
index
)
{
const
string
&
character
=
character_map
.
GetTerm
(
index
);
if
(
character
.
empty
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Term "
,
index
,
" is empty"
);
}
const
size_t
correct_size
=
UniLib
::
OneCharLen
(
character
.
data
());
if
(
character
.
size
()
!=
correct_size
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Term "
,
index
,
" should have size "
,
correct_size
,
": "
,
CharacterDebugString
(
character
));
}
if
(
!
UniLib
::
IsUTF8ValidCodepoint
(
character
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Term "
,
index
,
" is not valid UTF-8: "
,
CharacterDebugString
(
character
));
}
const
auto
*
bytes
=
reinterpret_cast
<
const
uint8
*>
(
character
.
data
());
if
(
character
.
size
()
==
1
)
{
DCHECK_EQ
(
single_byte_indices_
[
*
bytes
],
-
1
);
single_byte_indices_
[
*
bytes
]
=
index
;
}
else
{
const
uint32
key
=
MultiByteKey
(
bytes
,
character
.
size
());
DCHECK
(
multi_byte_indices_
.
find
(
key
)
==
multi_byte_indices_
.
end
());
multi_byte_indices_
[
key
]
=
index
;
}
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/unicode_dictionary.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
#define DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
#include <stddef.h>
#include <unordered_map>
#include <string>
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// A mapping from Unicode characters to indices.
//
// TODO(googleuser): Try integrating break chars into this mapping, maybe just for
// the ASCII break chars. They could be mapped directly to the break ID, so all
// one-byte characters are handled directly.
class
UnicodeDictionary
{
public:
// Creates an empty mapping.
UnicodeDictionary
();
// Loads a TermFrequencyMap from the |character_map_path| while applying the
// |min_frequency| and |max_num_terms|, and Reset()s this from it. On error,
// dies. This is for use in SharedStore; prefer Initialize() otherwise.
UnicodeDictionary
(
const
string
&
character_map_path
,
int
min_frequency
,
int
max_num_terms
);
// Resets this to the |character_map|. On error, returns non-OK.
tensorflow
::
Status
Reset
(
const
TermFrequencyMap
&
character_map
);
// Returns the index of the UTF-8 character spanning [|data|,|data|+|size|),
// or the |unknown_index| if not present in this.
int32
Lookup
(
const
char
*
data
,
size_t
size
,
int32
unknown_index
)
const
;
// Accessors.
size_t
size
()
const
{
return
size_
;
}
private:
// Removes all entries from this mapping.
void
Clear
();
// Returns an integer that uniquely identifies the multi-byte UTF-8 character
// spanning [|bytes|,|bytes|+|size|). Note that the returned value is not a
// Unicode codepoint.
static
uint32
MultiByteKey
(
const
uint8
*
bytes
,
size_t
size
);
// Number of entries in this mapping.
size_t
size_
=
0
;
// Dense mapping from single-byte UTF-8 (i.e., ASCII) character to index, or
// -1 if unmapped.
int32
single_byte_indices_
[
128
];
// Sparse mapping from multi-byte UTF-8 character to index.
std
::
unordered_map
<
uint32
,
int32
>
multi_byte_indices_
;
};
// Implementation details below.
inline
int32
UnicodeDictionary
::
Lookup
(
const
char
*
data
,
size_t
size
,
int32
unknown_index
)
const
{
DCHECK_GE
(
size
,
1
);
DCHECK_EQ
(
size
,
UniLib
::
OneCharLen
(
data
));
DCHECK
(
UniLib
::
IsUTF8ValidCodepoint
(
string
(
data
,
size
)));
const
auto
*
bytes
=
reinterpret_cast
<
const
uint8
*>
(
data
);
if
(
size
==
1
)
{
// Look up single-byte characters in the dense mapping.
DCHECK_LT
(
*
bytes
,
128
);
const
int32
index
=
single_byte_indices_
[
*
bytes
];
return
index
>=
0
?
index
:
unknown_index
;
}
else
{
// Look up multi-byte characters in the sparse mapping.
const
auto
it
=
multi_byte_indices_
.
find
(
MultiByteKey
(
bytes
,
size
));
return
it
!=
multi_byte_indices_
.
end
()
?
it
->
second
:
unknown_index
;
}
}
inline
uint32
UnicodeDictionary
::
MultiByteKey
(
const
uint8
*
bytes
,
size_t
size
)
{
DCHECK_GE
(
size
,
2
);
DCHECK_LE
(
size
,
4
);
uint32
value
=
static_cast
<
uint32
>
(
bytes
[
0
])
|
//
static_cast
<
uint32
>
(
bytes
[
1
])
<<
8
;
switch
(
size
)
{
case
4
:
value
|=
static_cast
<
uint32
>
(
bytes
[
3
])
<<
24
;
TF_FALLTHROUGH_INTENDED
;
case
3
:
value
|=
static_cast
<
uint32
>
(
bytes
[
2
])
<<
16
;
}
return
value
;
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
research/syntaxnet/dragnn/runtime/unicode_dictionary_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/unicode_dictionary.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/utf/utf.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
namespace
{
constexpr
char
kInvalidUtf8
[]
=
"
\xff\xff\xff\xff
"
;
constexpr
char
k1ByteCharacter
[]
=
"a"
;
constexpr
char
k2ByteCharacter
[]
=
"¼"
;
constexpr
char
k3ByteCharacter
[]
=
"好"
;
constexpr
char
k4ByteCharacter
[]
=
"𠜎"
;
// NB: String sizes are one more than expected from the trailing NUL.
static_assert
(
sizeof
(
k1ByteCharacter
)
/
sizeof
(
char
)
==
2
,
"1-byte character has the wrong size"
);
static_assert
(
sizeof
(
k2ByteCharacter
)
/
sizeof
(
char
)
==
3
,
"2-byte character has the wrong size"
);
static_assert
(
sizeof
(
k3ByteCharacter
)
/
sizeof
(
char
)
==
4
,
"3-byte character has the wrong size"
);
static_assert
(
sizeof
(
k4ByteCharacter
)
/
sizeof
(
char
)
==
5
,
"4-byte character has the wrong size"
);
// Tests that the dictionary is empty by default.
TEST
(
UnicodeDictionaryTest
,
EmptyByDefault
)
{
UnicodeDictionary
dictionary
;
EXPECT_EQ
(
dictionary
.
size
(),
0
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k1ByteCharacter
,
1
,
-
123
),
-
123
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k2ByteCharacter
,
2
,
-
123
),
-
123
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k3ByteCharacter
,
3
,
-
123
),
-
123
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k4ByteCharacter
,
4
,
-
123
),
-
123
);
}
// Tests that the dictionary can be reset to a copy of a term map.
TEST
(
UnicodeDictionaryTest
,
Reset
)
{
TermFrequencyMap
character_map
;
ASSERT_EQ
(
character_map
.
Increment
(
k1ByteCharacter
),
0
);
ASSERT_EQ
(
character_map
.
Increment
(
k2ByteCharacter
),
1
);
ASSERT_EQ
(
character_map
.
Increment
(
k3ByteCharacter
),
2
);
ASSERT_EQ
(
character_map
.
Increment
(
k4ByteCharacter
),
3
);
UnicodeDictionary
dictionary
;
TF_ASSERT_OK
(
dictionary
.
Reset
(
character_map
));
EXPECT_EQ
(
dictionary
.
size
(),
4
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k1ByteCharacter
,
1
,
-
123
),
0
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k2ByteCharacter
,
2
,
-
123
),
1
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k3ByteCharacter
,
3
,
-
123
),
2
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k4ByteCharacter
,
4
,
-
123
),
3
);
}
// Tests that the dictionary fails if a character is empty.
TEST
(
UnicodeDictionaryTest
,
EmptyCharacter
)
{
TermFrequencyMap
character_map
;
ASSERT_EQ
(
character_map
.
Increment
(
""
),
0
);
UnicodeDictionary
dictionary
;
EXPECT_THAT
(
dictionary
.
Reset
(
character_map
),
test
::
IsErrorWithSubstr
(
"Term 0 is empty"
));
}
// Tests that the dictionary fails if a term contains more than one character.
TEST
(
UnicodeDictionaryTest
,
MultipleCharacters
)
{
TermFrequencyMap
character_map
;
ASSERT_EQ
(
character_map
.
Increment
(
"1234"
),
0
);
UnicodeDictionary
dictionary
;
EXPECT_THAT
(
dictionary
.
Reset
(
character_map
),
test
::
IsErrorWithSubstr
(
"Term 0 should have size 1"
));
}
// Tests that the dictionary fails if a character is invalid.
TEST
(
UnicodeDictionaryTest
,
InvalidUtf8
)
{
TermFrequencyMap
character_map
;
ASSERT_EQ
(
character_map
.
Increment
(
kInvalidUtf8
),
0
);
UnicodeDictionary
dictionary
;
EXPECT_THAT
(
dictionary
.
Reset
(
character_map
),
test
::
IsErrorWithSubstr
(
"Term 0 is not valid UTF-8"
));
}
// Tests that the dictionary can be constructed from a file.
TEST
(
UnicodeDictionaryTest
,
ConstructFromFile
)
{
// Recall that terms are loaded in order of descending frequency.
const
string
character_map_path
=
WriteTermMap
({{
"too-infrequent"
,
1
},
{
k1ByteCharacter
,
2
},
{
k2ByteCharacter
,
3
},
{
k3ByteCharacter
,
4
},
{
k4ByteCharacter
,
5
}});
const
UnicodeDictionary
dictionary
(
character_map_path
,
2
,
0
);
EXPECT_EQ
(
dictionary
.
size
(),
4
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k1ByteCharacter
,
1
,
-
123
),
3
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k2ByteCharacter
,
2
,
-
123
),
2
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k3ByteCharacter
,
3
,
-
123
),
1
);
EXPECT_EQ
(
dictionary
.
Lookup
(
k4ByteCharacter
,
4
,
-
123
),
0
);
}
// Tests that the dictionary constructor dies on error.
TEST
(
UnicodeDictionaryTest
,
ConstructorDiesOnError
)
{
const
string
bad_path
=
WriteTermMap
({{
"1234"
,
1
}});
EXPECT_DEATH
(
UnicodeDictionary
dictionary
(
bad_path
,
0
,
0
),
"Term 0 should have size 1"
);
}
// Tests that the dictionary can map all valid codepoints.
TEST
(
UnicodeDictionaryTest
,
AllValidCodepoints
)
{
TermFrequencyMap
character_map
;
for
(
Rune
rune
=
0
;
rune
<
Runemax
;
++
rune
)
{
// Some codepoints are considered invalid, and UnicodeDictionary::Reset()
// will fail if it encounters them (see the InvalidUtf8 test). Skip those
// since we've already tested this in the "InvalidUtf8" test.
if
(
!
UniLib
::
IsValidCodepoint
(
rune
))
continue
;
char
data
[
UTFmax
];
const
int
size
=
runetochar
(
data
,
&
rune
);
const
string
character
(
data
,
size
);
const
int
index
=
character_map
.
Size
();
ASSERT_EQ
(
character_map
.
Increment
(
character
),
index
);
}
UnicodeDictionary
dictionary
;
TF_ASSERT_OK
(
dictionary
.
Reset
(
character_map
));
for
(
int
index
=
0
;
index
<
character_map
.
Size
();
++
index
)
{
const
string
&
character
=
character_map
.
GetTerm
(
index
);
EXPECT_EQ
(
dictionary
.
Lookup
(
character
.
data
(),
character
.
size
(),
-
1
),
index
);
}
}
}
// namespace
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/runtime/variable_store.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_VARIABLE_STORE_H_
#include <string>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
runtime
{
// Interface for a store holding named, precomputed variables. Implementations
// must be thread-compatible.
class
VariableStore
{
public:
VariableStore
(
const
VariableStore
&
that
)
=
delete
;
VariableStore
&
operator
=
(
const
VariableStore
&
that
)
=
delete
;
virtual
~
VariableStore
()
=
default
;
// Looks for the variable with the |name|, formats its content according to
// the requested |format| (see details below), and points the |area| at the
// result. The content of the variable before formatting is its content in
// the Python codebase. The |area| is valid while this lives, even after
// Close(). On error, returns non-OK and modifies nothing.
//
// Upon success the output |dimensions| will be cleared and assigned to
// the set of dimensions (num_elements,) in case of vectors, (num_rows,
// num_columns) in case of regular matrices, and (num_rows, num_columns,
// block_size) in case of blocked matrices.
//
// FORMAT_FLAT:
// Flattens the variable as if by tf.reshape(var, [-1]), and sets the |area|
// to a single sub-view that points at the flat array.
//
// FORMAT_ROW_MAJOR_MATRIX:
// Reshapes the variable into a matrix as if by tf.reshape(var, [-1, D]),
// where D is the variable's innermost dimension. Points each sub-view of
// the |area| at the corresponding row of the formatted matrix. Requires
// that the variable has rank at least 2.
//
// FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
// The variable must have shape [num_sub_matrices, num_rows, block_size],
// and is imported as a column-blocked row-major matrix, as documented in
// BlockedMatrixFormat (in math/types.h). The matrix may also be padded.
virtual
tensorflow
::
Status
Lookup
(
const
string
&
name
,
VariableSpec
::
Format
format
,
std
::
vector
<
size_t
>
*
dimensions
,
AlignedArea
*
area
)
=
0
;
// Looks up a FORMAT_FLAT variable as a Vector<T>.
template
<
class
T
>
tensorflow
::
Status
Lookup
(
const
string
&
name
,
Vector
<
T
>
*
vector
);
// Looks up a FORMAT_ROW_MAJOR_MATRIX as a Matrix<T>.
template
<
class
T
>
tensorflow
::
Status
Lookup
(
const
string
&
name
,
Matrix
<
T
>
*
matrix
);
// Looks up a FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX as a BlockedMatrix<T>.
template
<
class
T
>
tensorflow
::
Status
Lookup
(
const
string
&
name
,
BlockedMatrix
<
T
>
*
matrix
);
// Releases intermediate resources, if any. Does not invalidate the contents
// of variables returned by previous calls to Lookup*(), but future calls to
// Lookup*() are unsupported. On error, returns non-OK.
virtual
tensorflow
::
Status
Close
()
=
0
;
protected:
VariableStore
()
=
default
;
};
// Implementation details below.
template
<
class
T
>
tensorflow
::
Status
VariableStore
::
Lookup
(
const
string
&
name
,
Vector
<
T
>
*
vector
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_RETURN_IF_ERROR
(
Lookup
(
name
,
VariableSpec
::
FORMAT_FLAT
,
&
dimensions
,
&
area
));
if
(
area
.
num_views
()
!=
1
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Vector variable '"
,
name
,
"' should have 1 sub-view but has "
,
area
.
num_views
());
}
if
(
area
.
view_size
()
%
sizeof
(
T
)
!=
0
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Vector variable '"
,
name
,
"' does not divide into elements of size "
,
sizeof
(
T
));
}
*
vector
=
Vector
<
T
>
(
area
.
view
(
0
));
if
(
dimensions
.
size
()
!=
1
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Expected 1 dimensions, got "
,
dimensions
.
size
());
}
if
(
dimensions
[
0
]
!=
vector
->
size
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Vector size ("
,
vector
->
size
(),
") disagrees with dimensions[0] ("
,
dimensions
[
0
],
")"
);
}
return
tensorflow
::
Status
::
OK
();
}
template
<
class
T
>
tensorflow
::
Status
VariableStore
::
Lookup
(
const
string
&
name
,
Matrix
<
T
>
*
matrix
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_RETURN_IF_ERROR
(
Lookup
(
name
,
VariableSpec
::
FORMAT_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
));
if
(
dimensions
.
size
()
!=
2
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Expected 2 dimensions, got "
,
dimensions
.
size
());
}
if
(
area
.
view_size
()
%
sizeof
(
T
)
!=
0
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Matrix variable '"
,
name
,
"' does not divide into elements of size "
,
sizeof
(
T
));
}
*
matrix
=
Matrix
<
T
>
(
area
);
if
(
dimensions
[
0
]
!=
matrix
->
num_rows
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Matrix rows ("
,
matrix
->
num_rows
(),
") disagrees with dimensions[0] ("
,
dimensions
[
0
],
")"
);
}
if
(
dimensions
[
1
]
!=
matrix
->
num_columns
())
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Matrix columns ("
,
matrix
->
num_columns
(),
") disagrees with dimensions[1] ("
,
dimensions
[
1
],
")"
);
}
return
tensorflow
::
Status
::
OK
();
}
template
<
class
T
>
tensorflow
::
Status
VariableStore
::
Lookup
(
const
string
&
name
,
BlockedMatrix
<
T
>
*
matrix
)
{
AlignedArea
area
;
std
::
vector
<
size_t
>
dimensions
;
TF_RETURN_IF_ERROR
(
Lookup
(
name
,
VariableSpec
::
FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
,
&
dimensions
,
&
area
));
if
(
dimensions
.
size
()
!=
3
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Expected 3 dimensions, got "
,
dimensions
.
size
());
}
const
size_t
num_rows
=
dimensions
[
0
];
const
size_t
num_columns
=
dimensions
[
1
];
const
size_t
block_size
=
dimensions
[
2
];
if
(
area
.
view_size
()
!=
block_size
*
sizeof
(
T
))
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Area view size ("
,
area
.
view_size
(),
") doesn't correspond to block size ("
,
block_size
,
") times data type size ("
,
sizeof
(
T
),
")"
);
}
if
(
num_rows
*
num_columns
!=
area
.
num_views
()
*
block_size
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Rows * cols ("
,
num_rows
*
num_columns
,
") != area view size ("
,
area
.
num_views
()
*
block_size
,
")"
);
}
// Avoid modification on error.
BlockedMatrix
<
T
>
local_matrix
;
TF_RETURN_IF_ERROR
(
local_matrix
.
Reset
(
area
,
num_rows
,
num_columns
));
*
matrix
=
local_matrix
;
return
tensorflow
::
Status
::
OK
();
}
}
// namespace runtime
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_RUNTIME_VARIABLE_STORE_H_
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment