ModelZoo / ResNet50_tensorflow · Commits · edea2b67

Commit edea2b67 authored May 11, 2018 by Terry Koo

Remove runtime because reasons.

parent a4bb31d0
Changes 291
Showing 20 changed files with 0 additions and 2718 deletions (+0, -2718)
research/syntaxnet/dragnn/runtime/syntaxnet_tag_sequence_predictor.cc        +0  -130
research/syntaxnet/dragnn/runtime/syntaxnet_tag_sequence_predictor_test.cc   +0  -245
research/syntaxnet/dragnn/runtime/syntaxnet_word_sequence_extractor.cc       +0  -132
research/syntaxnet/dragnn/runtime/syntaxnet_word_sequence_extractor_test.cc  +0  -219
research/syntaxnet/dragnn/runtime/term_map_sequence_extractor.h              +0  -114
research/syntaxnet/dragnn/runtime/term_map_sequence_extractor_test.cc        +0  -153
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor.cc             +0  -59
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor.h              +0  -66
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor_test.cc        +0  -119
research/syntaxnet/dragnn/runtime/term_map_utils.cc                          +0  -77
research/syntaxnet/dragnn/runtime/term_map_utils.h                           +0  -47
research/syntaxnet/dragnn/runtime/term_map_utils_test.cc                     +0  -192
research/syntaxnet/dragnn/runtime/test/BUILD                                 +0  -110
research/syntaxnet/dragnn/runtime/test/fake_variable_store.cc                +0  -128
research/syntaxnet/dragnn/runtime/test/fake_variable_store.h                 +0  -113
research/syntaxnet/dragnn/runtime/test/fake_variable_store_test.cc           +0  -199
research/syntaxnet/dragnn/runtime/test/helpers.cc                            +0  -81
research/syntaxnet/dragnn/runtime/test/helpers.h                             +0  -179
research/syntaxnet/dragnn/runtime/test/helpers_test.cc                       +0  -151
research/syntaxnet/dragnn/runtime/test/network_test_base.cc                  +0  -204
Too many changes to show. To preserve performance only 291 of 291+ files are displayed.
research/syntaxnet/dragnn/runtime/syntaxnet_tag_sequence_predictor.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Predicts sequences of POS tags in SyntaxNetComponent batches.
class SyntaxNetTagSequencePredictor : public TermMapSequencePredictor {
 public:
  SyntaxNetTagSequencePredictor();

  // Implements SequencePredictor.
  bool Supports(const ComponentSpec &component_spec) const override;
  tensorflow::Status Initialize(const ComponentSpec &component_spec) override;
  tensorflow::Status Predict(Matrix<float> logits,
                             InputBatchCache *input) const override;

 private:
  // Whether to process sequences from left to right.
  bool left_to_right_ = true;
};

SyntaxNetTagSequencePredictor::SyntaxNetTagSequencePredictor()
    : TermMapSequencePredictor("tag-map") {}

bool SyntaxNetTagSequencePredictor::Supports(
    const ComponentSpec &component_spec) const {
  return TermMapSequencePredictor::SupportsTermMap(component_spec) &&
         component_spec.backend().registered_name() == "SyntaxNetComponent" &&
         component_spec.transition_system().registered_name() == "tagger";
}

tensorflow::Status SyntaxNetTagSequencePredictor::Initialize(
    const ComponentSpec &component_spec) {
  // Load all tags.
  constexpr int kMinFrequency = 0;
  constexpr int kMaxNumTerms = 0;
  TF_RETURN_IF_ERROR(TermMapSequencePredictor::InitializeTermMap(
      component_spec, kMinFrequency, kMaxNumTerms));

  if (term_map().Size() == 0) {
    return tensorflow::errors::InvalidArgument("Empty tag map");
  }

  const int map_num_tags = term_map().Size();
  const int spec_num_tags = component_spec.num_actions();
  if (map_num_tags != spec_num_tags) {
    return tensorflow::errors::InvalidArgument(
        "Tag count mismatch between term map (", map_num_tags,
        ") and ComponentSpec (", spec_num_tags, ")");
  }

  left_to_right_ = TransitionSystemTraits(component_spec).is_left_to_right;
  return tensorflow::Status::OK();
}

tensorflow::Status SyntaxNetTagSequencePredictor::Predict(
    Matrix<float> logits, InputBatchCache *input) const {
  if (logits.num_columns() != term_map().Size()) {
    return tensorflow::errors::InvalidArgument(
        "Logits shape mismatch: expected ", term_map().Size(),
        " columns but got ", logits.num_columns());
  }

  const std::vector<SyntaxNetSentence> &data =
      *input->GetAs<SentenceInputBatch>()->data();
  if (data.size() != 1) {
    return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
                                               data.size(), " elements");
  }

  Sentence *sentence = data[0].sentence();
  const int num_tokens = sentence->token_size();
  if (logits.num_rows() != num_tokens) {
    return tensorflow::errors::InvalidArgument(
        "Logits shape mismatch: expected ", num_tokens, " rows but got ",
        logits.num_rows());
  }

  int token_index = left_to_right_ ? 0 : num_tokens - 1;
  const int token_increment = left_to_right_ ? 1 : -1;
  for (int i = 0; i < num_tokens; ++i, token_index += token_increment) {
    const Vector<float> row = logits.row(i);
    Token *token = sentence->mutable_token(token_index);
    const float *const begin = row.begin();
    const float *const end = row.end();
    token->set_tag(term_map().GetTerm(std::max_element(begin, end) - begin));
  }

  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(SyntaxNetTagSequencePredictor);

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_tag_sequence_predictor_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

constexpr char kResourceName[] = "tag-map";

// Writes a default tag map and returns a path to it.
string GetTagMapPath() {
  static string *const kPath =
      new string(WriteTermMap({{"NOUN", 3}, {"VERB", 2}, {"DET", 1}}));
  return *kPath;
}

// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
  ComponentSpec component_spec;
  CHECK(TextFormat::ParseFromString(text, &component_spec));
  AddTermMapResource(kResourceName, path, &component_spec);
  return component_spec;
}

// Returns a ComponentSpec that the predictor will support.
ComponentSpec MakeSupportedSpec() {
  return MakeSpec(R"(transition_system { registered_name: 'tagger' }
                     backend { registered_name: 'SyntaxNetComponent' }
                     num_actions: 3)",
                  GetTagMapPath());
}

// Returns per-token tag logits.
UniqueMatrix<float> MakeLogits() {
  return UniqueMatrix<float>({{0.0, 0.0, 1.0},    // predict 2 = DET
                              {1.0, 0.0, 0.0},    // predict 0 = NOUN
                              {0.0, 1.0, 0.0},    // predict 1 = VERB
                              {0.0, 0.0, 1.0},    // predict 2 = DET
                              {1.0, 0.0, 0.0}});  // predict 0 = NOUN
}

// Returns a default sentence.
Sentence MakeSentence() {
  Sentence sentence;
  for (const string &word : {"the", "cat", "chased", "a", "mouse"}) {
    Token *token = sentence.add_token();
    token->set_start(0);  // never used; set because required field
    token->set_end(0);    // never used; set because required field
    token->set_word(word);
  }
  return sentence;
}

// Tests that the predictor supports an appropriate spec.
TEST(SyntaxNetTagSequencePredictorTest, Supported) {
  const ComponentSpec component_spec = MakeSupportedSpec();

  string name;
  TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
  EXPECT_EQ(name, "SyntaxNetTagSequencePredictor");
}

// Tests that the predictor requires the proper backend.
TEST(SyntaxNetTagSequencePredictorTest, WrongBackend) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.mutable_backend()->set_registered_name("bad");

  string name;
  EXPECT_THAT(
      SequencePredictor::Select(component_spec, &name),
      test::IsErrorWithSubstr("No SequencePredictor supports ComponentSpec"));
}

// Tests that the predictor requires the proper transition system.
TEST(SyntaxNetTagSequencePredictorTest, WrongTransitionSystem) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.mutable_transition_system()->set_registered_name("bad");

  string name;
  EXPECT_THAT(
      SequencePredictor::Select(component_spec, &name),
      test::IsErrorWithSubstr("No SequencePredictor supports ComponentSpec"));
}

// Tests that the predictor can be initialized and used to add POS tags to a
// sentence.
TEST(SyntaxNetTagSequencePredictorTest, InitializeAndPredict) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  std::unique_ptr<SequencePredictor> predictor;
  TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                      component_spec, &predictor));

  UniqueMatrix<float> logits = MakeLogits();
  const Sentence sentence = MakeSentence();
  InputBatchCache input(sentence.SerializeAsString());
  TF_ASSERT_OK(predictor->Predict(Matrix<float>(*logits), &input));

  const std::vector<string> predictions = input.SerializedData();
  ASSERT_EQ(predictions.size(), 1);
  Sentence tagged;
  ASSERT_TRUE(tagged.ParseFromString(predictions[0]));

  ASSERT_EQ(tagged.token_size(), 5);
  EXPECT_EQ(tagged.token(0).tag(), "DET");   // the
  EXPECT_EQ(tagged.token(1).tag(), "NOUN");  // cat
  EXPECT_EQ(tagged.token(2).tag(), "VERB");  // chased
  EXPECT_EQ(tagged.token(3).tag(), "DET");   // a
  EXPECT_EQ(tagged.token(4).tag(), "NOUN");  // mouse
}

// Tests that the predictor works on an empty sentence.
TEST(SyntaxNetTagSequencePredictorTest, EmptySentence) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  std::unique_ptr<SequencePredictor> predictor;
  TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                      component_spec, &predictor));

  AlignedView view;
  AlignedArea area;
  TF_ASSERT_OK(area.Reset(view, 0, 3 * sizeof(float)));
  Matrix<float> logits(area);

  const Sentence sentence;
  InputBatchCache input(sentence.SerializeAsString());
  TF_ASSERT_OK(predictor->Predict(logits, &input));

  const std::vector<string> predictions = input.SerializedData();
  ASSERT_EQ(predictions.size(), 1);
  Sentence tagged;
  ASSERT_TRUE(tagged.ParseFromString(predictions[0]));
  ASSERT_EQ(tagged.token_size(), 0);
}

// Tests that the predictor fails on an empty term map.
TEST(SyntaxNetTagSequencePredictorTest, EmptyTermMap) {
  const string path = WriteTermMap({});
  const ComponentSpec component_spec = MakeSpec("", path);

  std::unique_ptr<SequencePredictor> predictor;
  EXPECT_THAT(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                     component_spec, &predictor),
              test::IsErrorWithSubstr("Empty tag map"));
}

// Tests that Predict() fails if the batch is the wrong size.
TEST(SyntaxNetTagSequencePredictorTest, WrongBatchSize) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  std::unique_ptr<SequencePredictor> predictor;
  TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                      component_spec, &predictor));

  UniqueMatrix<float> logits = MakeLogits();
  const Sentence sentence = MakeSentence();
  const std::vector<string> data = {sentence.SerializeAsString(),
                                    sentence.SerializeAsString()};
  InputBatchCache input(data);
  EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
              test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}

// Tests that Initialize() fails if the term map doesn't match the specified
// number of actions.
TEST(SyntaxNetTagSequencePredictorTest, WrongNumActions) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.set_num_actions(1000);

  std::unique_ptr<SequencePredictor> predictor;
  EXPECT_THAT(
      SequencePredictor::New("SyntaxNetTagSequencePredictor", component_spec,
                             &predictor),
      test::IsErrorWithSubstr(
          "Tag count mismatch between term map (3) and ComponentSpec (1000)"));
}

// Tests that Predict() fails if the logits don't match the term map.
TEST(SyntaxNetTagSequencePredictorTest, WrongLogitsColumns) {
  const string path = WriteTermMap({{"a", 1}, {"b", 1}});
  const ComponentSpec component_spec = MakeSpec("num_actions: 2", path);
  std::unique_ptr<SequencePredictor> predictor;
  TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                      component_spec, &predictor));

  UniqueMatrix<float> logits = MakeLogits();
  Sentence sentence = MakeSentence();
  InputBatchCache input(sentence.SerializeAsString());
  EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
              test::IsErrorWithSubstr(
                  "Logits shape mismatch: expected 2 columns but got 3"));
}

// Tests that Predict() fails if the logits don't match the number of tokens.
TEST(SyntaxNetTagSequencePredictorTest, WrongLogitsRows) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  std::unique_ptr<SequencePredictor> predictor;
  TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
                                      component_spec, &predictor));

  UniqueMatrix<float> logits = MakeLogits();
  Sentence sentence = MakeSentence();
  sentence.mutable_token()->RemoveLast();  // bad
  InputBatchCache input(sentence.SerializeAsString());
  EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
              test::IsErrorWithSubstr(
                  "Logits shape mismatch: expected 4 rows but got 5"));
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_word_sequence_extractor.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Sequence extractor that extracts words from a SyntaxNetComponent batch.
class SyntaxNetWordSequenceExtractor
    : public TermMapSequenceExtractor<TermFrequencyMap> {
 public:
  SyntaxNetWordSequenceExtractor();

  // Implements SequenceExtractor.
  bool Supports(const FixedFeatureChannel &channel,
                const ComponentSpec &component_spec) const override;
  tensorflow::Status Initialize(const FixedFeatureChannel &channel,
                                const ComponentSpec &component_spec) override;
  tensorflow::Status GetIds(InputBatchCache *input,
                            std::vector<int32> *ids) const override;

 private:
  // Parses |fml| and sets |min_frequency| and |max_num_terms| to the specified
  // values.  If the |fml| does not specify a supported feature, returns non-OK
  // and modifies nothing.
  static tensorflow::Status ParseFml(const string &fml, int *min_frequency,
                                     int *max_num_terms);

  // Feature ID for unknown words.
  int32 unknown_id_ = -1;
};

SyntaxNetWordSequenceExtractor::SyntaxNetWordSequenceExtractor()
    : TermMapSequenceExtractor("word-map") {}

tensorflow::Status SyntaxNetWordSequenceExtractor::ParseFml(
    const string &fml, int *min_frequency, int *max_num_terms) {
  return ParseTermMapFml(fml, {"input", "token", "word"}, min_frequency,
                         max_num_terms);
}

bool SyntaxNetWordSequenceExtractor::Supports(
    const FixedFeatureChannel &channel,
    const ComponentSpec &component_spec) const {
  TransitionSystemTraits traits(component_spec);
  int unused_min_frequency = 0;
  int unused_max_num_terms = 0;
  const tensorflow::Status parse_fml_status =
      ParseFml(channel.fml(), &unused_min_frequency, &unused_max_num_terms);
  return TermMapSequenceExtractor::SupportsTermMap(channel, component_spec) &&
         parse_fml_status.ok() &&
         component_spec.backend().registered_name() == "SyntaxNetComponent" &&
         traits.is_sequential && traits.is_token_scale;
}

tensorflow::Status SyntaxNetWordSequenceExtractor::Initialize(
    const FixedFeatureChannel &channel, const ComponentSpec &component_spec) {
  int min_frequency = 0;
  int max_num_terms = 0;
  TF_RETURN_IF_ERROR(ParseFml(channel.fml(), &min_frequency, &max_num_terms));
  TF_RETURN_IF_ERROR(TermMapSequenceExtractor::InitializeTermMap(
      channel, component_spec, min_frequency, max_num_terms));

  unknown_id_ = term_map().Size();
  const int outside_id = unknown_id_ + 1;
  const int map_vocab_size = outside_id + 1;
  const int spec_vocab_size = channel.vocabulary_size();
  if (map_vocab_size != spec_vocab_size) {
    return tensorflow::errors::InvalidArgument(
        "Word vocabulary size mismatch between term map (", map_vocab_size,
        ") and ComponentSpec (", spec_vocab_size, ")");
  }

  return tensorflow::Status::OK();
}

tensorflow::Status SyntaxNetWordSequenceExtractor::GetIds(
    InputBatchCache *input, std::vector<int32> *ids) const {
  ids->clear();
  const std::vector<SyntaxNetSentence> &data =
      *input->GetAs<SentenceInputBatch>()->data();
  if (data.size() != 1) {
    return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
                                               data.size(), " elements");
  }

  const Sentence &sentence = *data[0].sentence();
  for (const Token &token : sentence.token()) {
    ids->push_back(term_map().LookupIndex(token.word(), unknown_id_));
  }

  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SyntaxNetWordSequenceExtractor);

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/syntaxnet_word_sequence_extractor_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

constexpr char kResourceName[] = "word-map";

// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
  ComponentSpec component_spec;
  CHECK(TextFormat::ParseFromString(text, &component_spec));
  AddTermMapResource(kResourceName, path, &component_spec);
  return component_spec;
}

// Returns a ComponentSpec that the extractor will support.
ComponentSpec MakeSupportedSpec() {
  return MakeSpec(
      R"(transition_system { registered_name: 'shift-only' }
         backend { registered_name: 'SyntaxNetComponent' }
         fixed_feature {}  # breaks hard-coded refs to channel 0
         fixed_feature { size: 1 fml: 'input.token.word(min-freq=2)' })",
      "/dev/null");
}

// Returns a default sentence.
Sentence MakeSentence() {
  Sentence sentence;
  for (const string &word : {"a", "bc", "def"}) {
    Token *token = sentence.add_token();
    token->set_start(0);  // never used; set because required field
    token->set_end(0);    // never used; set because required field
    token->set_word(word);
  }
  return sentence;
}

// Tests that the extractor supports an appropriate spec.
TEST(SyntaxNetWordSequenceExtractorTest, Supported) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  string name;
  TF_ASSERT_OK(SequenceExtractor::Select(channel, component_spec, &name));
  EXPECT_EQ(name, "SyntaxNetWordSequenceExtractor");
}

// Tests that the extractor requires the proper backend.
TEST(SyntaxNetWordSequenceExtractorTest, WrongBackend) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.mutable_backend()->set_registered_name("bad");
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  string name;
  EXPECT_THAT(SequenceExtractor::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}

// Tests that the extractor requires the proper transition system.
TEST(SyntaxNetWordSequenceExtractorTest, WrongTransitionSystem) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.mutable_transition_system()->set_registered_name("bad");
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  string name;
  EXPECT_THAT(SequenceExtractor::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}

// Expects that the |fml| is rejected by the extractor.
void ExpectRejectedFml(const string &fml) {
  ComponentSpec component_spec = MakeSupportedSpec();
  component_spec.mutable_fixed_feature(1)->set_fml(fml);
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  string name;
  EXPECT_THAT(SequenceExtractor::Select(channel, component_spec, &name),
              test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}

// Tests that the extractor requires the proper FML.
TEST(SyntaxNetWordSequenceExtractorTest, WrongFml) {
  ExpectRejectedFml("bad");
  EXPECT_DEATH(ExpectRejectedFml("input.token.word("),
               "Error in feature model");
  EXPECT_DEATH(ExpectRejectedFml("input.token.word()"),
               "Error in feature model");
  ExpectRejectedFml("input.token.word(10)");
  EXPECT_DEATH(ExpectRejectedFml("input.token.word(min-freq=)"),
               "Error in feature model");
  EXPECT_DEATH(ExpectRejectedFml("input.token.word(min-freq=10"),
               "Error in feature model");
  ExpectRejectedFml("input.token.word(min-freq=ten)");
  ExpectRejectedFml("input.token.word(min_freq=10)");  // underscore
}

// Tests that the extractor can be initialized and used to extract feature IDs.
TEST(SyntaxNetWordSequenceExtractorTest, InitializeAndGetIds) {
  // Terms are sorted by descending frequency, so this ensures a=0, bc=1, etc.
  // Note that "e" is too infrequent, so vocabulary_size=5 from 3 terms plus 2
  // special values.
  const string path = WriteTermMap({{"a", 5}, {"bc", 3}, {"d", 2}, {"e", 1}});
  const ComponentSpec component_spec = MakeSpec(
      "fixed_feature {} "
      "fixed_feature { vocabulary_size:5 fml:'input.token.word(min-freq=2)' }",
      path);
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  std::unique_ptr<SequenceExtractor> extractor;
  TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor",
                                      channel, component_spec, &extractor));

  const Sentence sentence = MakeSentence();
  InputBatchCache input(sentence.SerializeAsString());
  std::vector<int32> ids;
  TF_ASSERT_OK(extractor->GetIds(&input, &ids));

  const std::vector<int32> expected_ids = {0, 1, 3};
  EXPECT_EQ(ids, expected_ids);
}

// Tests that an empty term map works.
TEST(SyntaxNetWordSequenceExtractorTest, EmptyTermMap) {
  const string path = WriteTermMap({});
  const ComponentSpec component_spec = MakeSpec(
      "fixed_feature {} "
      "fixed_feature { fml:'input.token.word' vocabulary_size:2 }",
      path);
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  std::unique_ptr<SequenceExtractor> extractor;
  TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor",
                                      channel, component_spec, &extractor));

  const Sentence sentence = MakeSentence();
  InputBatchCache input(sentence.SerializeAsString());
  std::vector<int32> ids = {1, 2, 3, 4};  // should be overwritten
  TF_ASSERT_OK(extractor->GetIds(&input, &ids));

  const std::vector<int32> expected_ids = {0, 0, 0};
  EXPECT_EQ(ids, expected_ids);
}

// Tests that GetIds() fails if the batch is the wrong size.
TEST(SyntaxNetWordSequenceExtractorTest, WrongBatchSize) {
  const string path = WriteTermMap({});
  const ComponentSpec component_spec = MakeSpec(
      "fixed_feature {} "
      "fixed_feature { fml:'input.token.word' vocabulary_size:2 }",
      path);
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  std::unique_ptr<SequenceExtractor> extractor;
  TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor",
                                      channel, component_spec, &extractor));

  const Sentence sentence = MakeSentence();
  const std::vector<string> data = {sentence.SerializeAsString(),
                                    sentence.SerializeAsString()};
  InputBatchCache input(data);
  std::vector<int32> ids;
  EXPECT_THAT(extractor->GetIds(&input, &ids),
              test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}

// Tests that initialization fails if the vocabulary size does not match.
TEST(SyntaxNetWordSequenceExtractorTest, WrongVocabularySize) {
  const string path = WriteTermMap({});
  const ComponentSpec component_spec = MakeSpec(
      "fixed_feature {} "
      "fixed_feature { fml:'input.token.word' vocabulary_size:1000 }",
      path);
  const FixedFeatureChannel &channel = component_spec.fixed_feature(1);

  std::unique_ptr<SequenceExtractor> extractor;
  EXPECT_THAT(SequenceExtractor::New("SyntaxNetWordSequenceExtractor", channel,
                                     component_spec, &extractor),
              test::IsErrorWithSubstr("Word vocabulary size mismatch between term "
                                      "map (2) and ComponentSpec (1000)"));
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/term_map_sequence_extractor.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "syntaxnet/base.h"
#include "syntaxnet/shared_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Base class for TermFrequencyMap-based sequence feature extractors.  Requires
// the component to have a single fixed feature and a TermFrequencyMap
// resource.  Templated on a |TermMap| type, which should have a 3-arg
// constructor similar to TermFrequencyMap's.
template <class TermMap>
class TermMapSequenceExtractor : public SequenceExtractor {
 public:
  // Creates a sequence extractor that will load a term map from the resource
  // named |resource_name|.
  explicit TermMapSequenceExtractor(const string &resource_name);
  ~TermMapSequenceExtractor() override;

  // Returns true if the |channel| of the |component_spec| is compatible with
  // this.  Subclasses should call this from their Supports().
  bool SupportsTermMap(const FixedFeatureChannel &channel,
                       const ComponentSpec &component_spec) const;

  // Loads a term map from the |channel| of the |component_spec|, applying the
  // |min_frequency| and |max_num_terms| when loading the term map.  On error,
  // returns non-OK.  Subclasses should call this from their Initialize().
  tensorflow::Status InitializeTermMap(const FixedFeatureChannel &channel,
                                       const ComponentSpec &component_spec,
                                       int min_frequency, int max_num_terms);

 protected:
  // Returns the current term map.  Only valid after InitializeTermMap().
  const TermMap &term_map() const { return *term_map_; }

 private:
  // Name of the resource from which to load a term map.
  const string resource_name_;

  // Mapping from terms to feature IDs.  Owned by SharedStore.
  const TermMap *term_map_ = nullptr;
};

// Implementation details below.

template <class TermMap>
TermMapSequenceExtractor<TermMap>::TermMapSequenceExtractor(
    const string &resource_name)
    : resource_name_(resource_name) {}

template <class TermMap>
TermMapSequenceExtractor<TermMap>::~TermMapSequenceExtractor() {
  if (!SharedStore::Release(term_map_)) {
    LOG(ERROR) << "Failed to release term map for resource " << resource_name_;
  }
}

template <class TermMap>
bool TermMapSequenceExtractor<TermMap>::SupportsTermMap(
    const FixedFeatureChannel &channel,
    const ComponentSpec &component_spec) const {
  return LookupTermMapResourcePath(resource_name_, component_spec) != nullptr &&
         channel.size() == 1;
}

template <class TermMap>
tensorflow::Status TermMapSequenceExtractor<TermMap>::InitializeTermMap(
    const FixedFeatureChannel &channel, const ComponentSpec &component_spec,
    int min_frequency, int max_num_terms) {
  const string *path =
      LookupTermMapResourcePath(resource_name_, component_spec);
  if (path == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "No compatible resource named '", resource_name_,
        "' in ComponentSpec: ", component_spec.ShortDebugString());
  }

  term_map_ = SharedStoreUtils::GetWithDefaultName<TermMap>(
      *path, min_frequency, max_num_terms);
  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
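The header comments above describe the intended subclass protocol: a concrete extractor forwards Supports() to SupportsTermMap(), forwards Initialize() to InitializeTermMap(), and then reads feature IDs via term_map(). A minimal sketch of such a subclass follows; the class name "ExampleSequenceExtractor" and the resource name "my-map" are illustrative assumptions only, not part of this commit (the registered subclass that was actually removed is SyntaxNetWordSequenceExtractor above).

// Hypothetical minimal subclass of TermMapSequenceExtractor (sketch only).
class ExampleSequenceExtractor
    : public TermMapSequenceExtractor<TermFrequencyMap> {
 public:
  ExampleSequenceExtractor() : TermMapSequenceExtractor("my-map") {}

  // Delegates the resource and channel-size checks to the base class.
  bool Supports(const FixedFeatureChannel &channel,
                const ComponentSpec &component_spec) const override {
    return SupportsTermMap(channel, component_spec);
  }

  // Loads the shared term map; 0/0 means no frequency or size limits.
  tensorflow::Status Initialize(const FixedFeatureChannel &channel,
                                const ComponentSpec &component_spec) override {
    return InitializeTermMap(channel, component_spec, /*min_frequency=*/0,
                             /*max_num_terms=*/0);
  }

  // A real extractor would push term_map().LookupIndex(...) for each input.
  tensorflow::Status GetIds(InputBatchCache *input,
                            std::vector<int32> *ids) const override {
    ids->clear();
    return tensorflow::Status::OK();
  }
};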
research/syntaxnet/dragnn/runtime/term_map_sequence_extractor_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

constexpr char kResourceName[] = "term-map";
constexpr int kMinFrequency = 2;
constexpr int kMaxNumTerms = 0;  // no limit

// A subclass for tests.
class BasicTermMapSequenceExtractor
    : public TermMapSequenceExtractor<TermFrequencyMap> {
 public:
  BasicTermMapSequenceExtractor()
      : TermMapSequenceExtractor(kResourceName) {}

  // Implements SequenceExtractor.  These methods are never called, but must be
  // defined so we can instantiate the class.
  bool Supports(const FixedFeatureChannel &,
                const ComponentSpec &) const override {
    return true;
  }
  tensorflow::Status Initialize(const FixedFeatureChannel &,
                                const ComponentSpec &) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status GetIds(InputBatchCache *,
                            std::vector<int32> *) const override {
    return tensorflow::Status::OK();
  }

  // Publicizes the TermFrequencyMap accessor.
  using TermMapSequenceExtractor::term_map;
};

// Returns a FixedFeatureChannel parsed from the |text|.
FixedFeatureChannel MakeChannel(const string &text) {
  FixedFeatureChannel channel;
  CHECK(TextFormat::ParseFromString(text, &channel));
  return channel;
}

// Returns a ComponentSpec that contains a term map resource pointing at the
// |path|.
ComponentSpec MakeSpec(const string &path) {
  ComponentSpec component_spec;
  AddTermMapResource(kResourceName, path, &component_spec);
  return component_spec;
}

// Tests that a term map can be successfully read.
TEST(TermMapSequenceExtractorTest, NormalOperation) {
  const string path = WriteTermMap({{"too-infrequent", kMinFrequency - 1},
                                    {"hello", kMinFrequency},
                                    {"world", kMinFrequency + 1}});
  const FixedFeatureChannel channel = MakeChannel("size:1");
  const ComponentSpec spec = MakeSpec(path);

  BasicTermMapSequenceExtractor extractor;
  ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));
  TF_ASSERT_OK(
      extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms));

  // NB: Terms are sorted by frequency.
  EXPECT_EQ(extractor.term_map().Size(), 2);
  EXPECT_EQ(extractor.term_map().LookupIndex("hello", -1), 1);
  EXPECT_EQ(extractor.term_map().LookupIndex("world", -1), 0);
  EXPECT_EQ(extractor.term_map().LookupIndex("unknown", -1), -1);
}

// Tests that SupportsTermMap() requires the fixed feature channel to have
// size 1.
TEST(TermMapSequenceExtractorTest, FixedFeatureSize) {
  const BasicTermMapSequenceExtractor extractor;
  ASSERT_TRUE(
      extractor.SupportsTermMap(MakeChannel("size:1"), MakeSpec("/dev/null")));
  EXPECT_FALSE(
      extractor.SupportsTermMap(MakeChannel("size:0"), MakeSpec("/dev/null")));
  EXPECT_FALSE(
      extractor.SupportsTermMap(MakeChannel("size:2"), MakeSpec("/dev/null")));
}

// Tests that SupportsTermMap() requires a resource with the proper name.
TEST(TermMapSequenceExtractorTest, ResourceName) {
  const BasicTermMapSequenceExtractor extractor;
  const FixedFeatureChannel channel = MakeChannel("size:1");
  ComponentSpec spec = MakeSpec("/dev/null");
  ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));

  spec.mutable_resource(0)->set_name("whatever");
  EXPECT_FALSE(extractor.SupportsTermMap(channel, spec));
}

// Tests that InitializeTermMap() fails if the term map cannot be found.
TEST(TermMapSequenceExtractorTest, InitializeWithNoTermMap) {
  BasicTermMapSequenceExtractor extractor;
  const FixedFeatureChannel channel;
  const ComponentSpec spec;
  EXPECT_THAT(
      extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms),
      test::IsErrorWithSubstr("No compatible resource"));
}

// Tests that InitializeTermMap() requires a proper term map file.
TEST(TermMapSequenceExtractorTest, InvalidPath) {
  BasicTermMapSequenceExtractor extractor;
  const FixedFeatureChannel channel = MakeChannel("size:1");
  const ComponentSpec spec = MakeSpec("/some/bad/path");
  ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));
  EXPECT_DEATH(
      extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms)
          .IgnoreError(),
      "/some/bad/path");
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "syntaxnet/shared_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

TermMapSequencePredictor::TermMapSequencePredictor(const string &resource_name)
    : resource_name_(resource_name) {}

TermMapSequencePredictor::~TermMapSequencePredictor() {
  if (!SharedStore::Release(term_map_)) {
    LOG(ERROR) << "Failed to release term map for resource " << resource_name_;
  }
}

bool TermMapSequencePredictor::SupportsTermMap(
    const ComponentSpec &component_spec) const {
  return LookupTermMapResourcePath(resource_name_, component_spec) != nullptr;
}

tensorflow::Status TermMapSequencePredictor::InitializeTermMap(
    const ComponentSpec &component_spec, int min_frequency,
    int max_num_terms) {
  const string *path =
      LookupTermMapResourcePath(resource_name_, component_spec);
  if (path == nullptr) {
    return tensorflow::errors::InvalidArgument(
        "No compatible resource named '", resource_name_,
        "' in ComponentSpec: ", component_spec.ShortDebugString());
  }

  term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
      *path, min_frequency, max_num_terms);
  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Base class for predictors whose output label set is defined by a term map.
// Requires the component to have a TermFrequencyMap resource.
class TermMapSequencePredictor : public SequencePredictor {
 public:
  // Creates a sequence predictor that will load a term map from the resource
  // named |resource_name|.
  explicit TermMapSequencePredictor(const string &resource_name);
  ~TermMapSequencePredictor() override;

  // Returns true if the |component_spec| is compatible with this.  Subclasses
  // should call this from their Supports().
  bool SupportsTermMap(const ComponentSpec &component_spec) const;

  // Loads a term map from the |component_spec|, applying the |min_frequency|
  // and |max_num_terms| when loading the term map.  On error, returns non-OK.
  // Subclasses should call this from their Initialize().
  tensorflow::Status InitializeTermMap(const ComponentSpec &component_spec,
                                       int min_frequency, int max_num_terms);

 protected:
  // Returns the current term map.  Only valid after InitializeTermMap().
  const TermFrequencyMap &term_map() const { return *term_map_; }

 private:
  // Name of the resource from which to load a term map.
  const string resource_name_;

  // Mapping from strings to feature IDs.  Owned by SharedStore.
  const TermFrequencyMap *term_map_ = nullptr;
};

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
research/syntaxnet/dragnn/runtime/term_map_sequence_predictor_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

constexpr char kResourceName[] = "term-map";
constexpr int kMinFrequency = 2;
constexpr int kMaxNumTerms = 0;  // no limit

// A subclass for tests.
class BasicTermMapSequencePredictor : public TermMapSequencePredictor {
 public:
  BasicTermMapSequencePredictor()
      : TermMapSequencePredictor(kResourceName) {}

  // Implements SequencePredictor.  These methods are never called, but must be
  // defined so we can instantiate the class.
  bool Supports(const ComponentSpec &) const override { return true; }
  tensorflow::Status Initialize(const ComponentSpec &) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
    return tensorflow::Status::OK();
  }

  // Publicizes the TermFrequencyMap accessor.
  using TermMapSequencePredictor::term_map;
};

// Returns a ComponentSpec that contains a term map resource pointing at the
// |path|.
ComponentSpec MakeSpec(const string &path) {
  ComponentSpec component_spec;
  AddTermMapResource(kResourceName, path, &component_spec);
  return component_spec;
}

// Tests that a term map can be successfully read.
TEST(TermMapSequencePredictorTest, NormalOperation) {
  const string path = WriteTermMap({{"too-infrequent", kMinFrequency - 1},
                                    {"hello", kMinFrequency},
                                    {"world", kMinFrequency + 1}});
  const ComponentSpec spec = MakeSpec(path);

  BasicTermMapSequencePredictor predictor;
  ASSERT_TRUE(predictor.SupportsTermMap(spec));
  TF_ASSERT_OK(predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms));

  // NB: Terms are sorted by frequency.
  EXPECT_EQ(predictor.term_map().Size(), 2);
  EXPECT_EQ(predictor.term_map().LookupIndex("hello", -1), 1);
  EXPECT_EQ(predictor.term_map().LookupIndex("world", -1), 0);
  EXPECT_EQ(predictor.term_map().LookupIndex("unknown", -1), -1);
}

// Tests that SupportsTermMap() requires a resource with the proper name.
TEST(TermMapSequencePredictorTest, ResourceName) {
  const BasicTermMapSequencePredictor predictor;
  ComponentSpec spec = MakeSpec("/dev/null");
  ASSERT_TRUE(predictor.SupportsTermMap(spec));

  spec.mutable_resource(0)->set_name("whatever");
  EXPECT_FALSE(predictor.SupportsTermMap(spec));
}

// Tests that InitializeTermMap() fails if the term map cannot be found.
TEST(TermMapSequencePredictorTest, InitializeWithNoTermMap) {
  BasicTermMapSequencePredictor predictor;
  const ComponentSpec spec;
  EXPECT_THAT(predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms),
              test::IsErrorWithSubstr("No compatible resource"));
}

// Tests that InitializeTermMap() requires a proper term map file.
TEST(TermMapSequencePredictorTest, InvalidPath) {
  BasicTermMapSequencePredictor predictor;
  const ComponentSpec spec = MakeSpec("/some/bad/path");
  ASSERT_TRUE(predictor.SupportsTermMap(spec));
  EXPECT_DEATH(
      predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms)
          .IgnoreError(),
      "/some/bad/path");
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/term_map_utils.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/fml_parsing.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Attributes for extracting term map feature options.
struct TermMapAttributes : public FeatureFunctionAttributes {
  // Minimum frequency for included terms.
  Optional<int32> min_frequency{"min-freq", 0, this};

  // Maximum number of terms to include.
  Optional<int32> max_num_terms{"max-num-terms", 0, this};
};

// Returns true if the |record_format| is compatible with a TermFrequencyMap.
bool CompatibleRecordFormat(const string &record_format) {
  return record_format.empty() || record_format == "TermFrequencyMap";
}

}  // namespace

const string *LookupTermMapResourcePath(const string &resource_name,
                                        const ComponentSpec &component_spec) {
  for (const Resource &resource : component_spec.resource()) {
    if (resource.name() != resource_name) continue;
    if (resource.part_size() != 1) continue;
    const Part &part = resource.part(0);
    if (part.file_format() != "text") continue;
    if (!CompatibleRecordFormat(part.record_format())) continue;
    return &part.file_pattern();
  }
  return nullptr;
}

tensorflow::Status ParseTermMapFml(const string &fml,
                                   const std::vector<string> &types,
                                   int *min_frequency, int *max_num_terms) {
  FeatureFunctionDescriptor function;
  TF_RETURN_IF_ERROR(ParseFeatureChainFml(fml, types, &function));
  if (function.argument() != 0) {
    return tensorflow::errors::InvalidArgument(
        "TermFrequencyMap-based feature should have no argument: ", fml);
  }

  TermMapAttributes attributes;
  TF_RETURN_IF_ERROR(attributes.Reset(function));

  // Success; make modifications.
  *min_frequency = attributes.min_frequency();
  *max_num_terms = attributes.max_num_terms();
  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/term_map_utils.h
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
#define DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
#include <string>
#include <vector>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// Returns the path to the TermFrequencyMap resource named |resource_name| in
// the |component_spec|, or null if not found.
const string *LookupTermMapResourcePath(const string &resource_name,
                                        const ComponentSpec &component_spec);

// Parses the |fml| as a chain of |types| ending in a TermFrequencyMap-based
// feature with "min-freq" and "max-num-terms" options.  Sets |min_frequency|
// and |max_num_terms| to the option values.  On error, returns non-OK and
// modifies nothing.
tensorflow::Status ParseTermMapFml(const string &fml,
                                   const std::vector<string> &types,
                                   int *min_frequency, int *max_num_terms);

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet

#endif  // DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
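For reference, a short usage sketch of ParseTermMapFml(); the FML string and the expected option values are taken from term_map_utils_test.cc below, while the surrounding calling context is only an assumption.

// Sketch: parse "min-freq" and "max-num-terms" options from an FML string.
int min_frequency = -1;
int max_num_terms = -1;
TF_CHECK_OK(ParseTermMapFml("path.to.foo(min-freq=12,max-num-terms=3456)",
                            {"path", "to", "foo"},
                            &min_frequency, &max_num_terms));
// Now min_frequency == 12 and max_num_terms == 3456; omitted options stay 0.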
research/syntaxnet/dragnn/runtime/term_map_utils_test.cc
deleted 100644 → 0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_utils.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

constexpr char kResourceName[] = "term-map";
constexpr char kResourcePath[] = "/path/to/term-map";

// Returns a ComponentSpec with a term map resource named |kResourceName| that
// points at |kResourcePath|.
ComponentSpec MakeSpec() {
  ComponentSpec spec;
  AddTermMapResource(kResourceName, kResourcePath, &spec);
  return spec;
}

// Tests that a term map resource can be successfully read.
TEST(LookupTermMapResourcePathTest, Success) {
  const ComponentSpec spec = MakeSpec();
  const string *path = LookupTermMapResourcePath(kResourceName, spec);
  ASSERT_NE(path, nullptr);
  EXPECT_EQ(*path, kResourcePath);
}

// Tests that the returned path is null for an empty spec.
TEST(LookupTermMapResourcePathTest, EmptySpec) {
  const ComponentSpec spec;
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}

// Tests that the returned path is null for the wrong resource name.
TEST(LookupTermMapResourcePathTest, WrongName) {
  ComponentSpec spec = MakeSpec();
  spec.mutable_resource(0)->set_name("bad");
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}

// Tests that the returned path is null for the wrong number of parts.
TEST(LookupTermMapResourcePathTest, WrongNumberOfParts) {
  ComponentSpec spec = MakeSpec();
  spec.mutable_resource(0)->clear_part();
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);

  spec.mutable_resource(0)->add_part();
  spec.mutable_resource(0)->add_part();
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}

// Tests that the returned path is null for the wrong file format.
TEST(LookupTermMapResourcePathTest, WrongFileFormat) {
  ComponentSpec spec = MakeSpec();
  spec.mutable_resource(0)->mutable_part(0)->set_file_format("bad");
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}

// Tests that the returned path is null for the wrong record format.
TEST(LookupTermMapResourcePathTest, WrongRecordFormat) {
  ComponentSpec spec = MakeSpec();
  spec.mutable_resource(0)->mutable_part(0)->set_record_format("bad");
  EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}

// Tests that alternate record formats are accepted.
TEST(LookupTermMapResourcePathTest, SuccessWithAlternateRecordFormat) {
  ComponentSpec spec = MakeSpec();
  spec.mutable_resource(0)->mutable_part(0)->set_record_format(
      "TermFrequencyMap");
  const string *path = LookupTermMapResourcePath(kResourceName, spec);
  ASSERT_NE(path, nullptr);
  EXPECT_EQ(*path, kResourcePath);
}

// Tests that ParseTermMapFml() correctly parses term map feature options.
TEST(ParseTermMapFmlTest, Success) {
  int min_frequency = -1;
  int max_num_terms = -1;

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo", {"path", "to", "foo"},
                               &min_frequency, &max_num_terms));
  EXPECT_EQ(min_frequency, 0);
  EXPECT_EQ(max_num_terms, 0);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(min-freq=5)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 5);
  EXPECT_EQ(max_num_terms, 0);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(max-num-terms=1000)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 0);
  EXPECT_EQ(max_num_terms, 1000);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(min-freq=12,max-num-terms=3456)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 12);
  EXPECT_EQ(max_num_terms, 3456);
}

// Tests that ParseTermMapFml() tolerates a zero argument.
TEST(ParseTermMapFmlTest, SuccessWithZeroArgument) {
  int min_frequency = -1;
  int max_num_terms = -1;

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0)", {"path", "to", "foo"},
                               &min_frequency, &max_num_terms));
  EXPECT_EQ(min_frequency, 0);
  EXPECT_EQ(max_num_terms, 0);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,min-freq=5)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 5);
  EXPECT_EQ(max_num_terms, 0);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,max-num-terms=1000)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 0);
  EXPECT_EQ(max_num_terms, 1000);

  TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,min-freq=12,max-num-terms=3456)",
                               {"path", "to", "foo"}, &min_frequency,
                               &max_num_terms));
  EXPECT_EQ(min_frequency, 12);
  EXPECT_EQ(max_num_terms, 3456);
}

// Tests that ParseTermMapFml() fails on a non-zero argument.
TEST(ParseTermMapFmlTest, NonZeroArgument) {
  int min_frequency = -1;
  int max_num_terms = -1;

  EXPECT_THAT(ParseTermMapFml("path.to.foo(1)", {"path", "to", "foo"},
                              &min_frequency, &max_num_terms),
              test::IsErrorWithSubstr(
                  "TermFrequencyMap-based feature should have no argument"));
  EXPECT_EQ(min_frequency, -1);
  EXPECT_EQ(max_num_terms, -1);
}

// Tests that ParseTermMapFml() fails on an unknown feature option.
TEST(ParseTermMapFmlTest, UnknownOption) {
  int min_frequency = -1;
  int max_num_terms = -1;

  EXPECT_THAT(ParseTermMapFml("path.to.foo(unknown=1)", {"path", "to", "foo"},
                              &min_frequency, &max_num_terms),
              test::IsErrorWithSubstr("Unknown attribute"));
  EXPECT_EQ(min_frequency, -1);
  EXPECT_EQ(max_num_terms, -1);
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/test/BUILD
deleted
100644 → 0
View file @
a4bb31d0
package(
    default_visibility = ["//visibility:public"],
)

cc_library(
    name = "helpers",
    testonly = 1,
    srcs = ["helpers.cc"],
    hdrs = ["helpers.h"],
    deps = [
        "//dragnn/runtime:alignment",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:sgemvv",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/math:types",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "helpers_test",
    size = "small",
    srcs = ["helpers_test.cc"],
    deps = [
        ":helpers",
        "//dragnn/runtime:alignment",
        "//dragnn/runtime/math:types",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "fake_variable_store",
    testonly = 1,
    srcs = ["fake_variable_store.cc"],
    hdrs = ["fake_variable_store.h"],
    deps = [
        ":helpers",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime:alignment",
        "//dragnn/runtime:variable_store",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "fake_variable_store_test",
    size = "small",
    srcs = ["fake_variable_store_test.cc"],
    deps = [
        ":fake_variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/runtime:alignment",
        "//dragnn/runtime/math:types",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "network_test_base",
    testonly = 1,
    srcs = ["network_test_base.cc"],
    hdrs = ["network_test_base.h"],
    deps = [
        ":fake_variable_store",
        "//dragnn/core/test:mock_compute_session",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/runtime:extensions",
        "//dragnn/runtime:flexible_matrix_kernel",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:session_state",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "term_map_helpers",
    testonly = 1,
    srcs = ["term_map_helpers.cc"],
    hdrs = ["term_map_helpers.h"],
    deps = [
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_test(
    name = "term_map_helpers_test",
    size = "small",
    srcs = ["term_map_helpers_test.cc"],
    deps = [
        ":term_map_helpers",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
research/syntaxnet/dragnn/runtime/test/fake_variable_store.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/fake_variable_store.h"
#include <string.h>
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

void FakeVariableStore::AddOrDie(const string &name,
                                 const std::vector<std::vector<float>> &data,
                                 VariableSpec::Format format) {
  CHECK(variables_[name].empty()) << "Adding duplicate variable: " << name;
  FormatMap formats;

  // Add a flattened version.
  std::vector<std::vector<float>> flat(1);
  for (const auto &row : data) {
    for (const float value : row) flat[0].push_back(value);
  }
  formats[VariableSpec::FORMAT_FLAT] = Variable(flat);

  // Add the |data| in its natural row-major format.
  formats[VariableSpec::FORMAT_ROW_MAJOR_MATRIX] = Variable(data);

  // Add the |data| as a trivial blocked matrix with one block---i.e., block
  // size equal to the number of columns.  Conveniently, this matrix has the
  // same underlying data layout as a plain matrix.
  formats[VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX] =
      Variable(data);

  // If |format| is FORMAT_UNKNOWN, keep all formats.  Otherwise, only keep the
  // specified format.
  if (format == VariableSpec::FORMAT_UNKNOWN) {
    variables_[name] = std::move(formats);
  } else {
    variables_[name][format] = std::move(formats[format]);
  }
}

void FakeVariableStore::SetBlockedDimensionOverride(
    const string &name, const std::vector<size_t> &dimensions) {
  override_blocked_dimensions_[name] = dimensions;
}

tensorflow::Status FakeVariableStore::Lookup(const string &name,
                                             VariableSpec::Format format,
                                             std::vector<size_t> *dimensions,
                                             AlignedArea *area) {
  const auto it = variables_.find(name);
  if (it == variables_.end()) {
    return tensorflow::errors::InvalidArgument("Unknown variable: ", name);
  }
  FormatMap &formats = it->second;
  if (formats.find(format) == formats.end()) {
    return tensorflow::errors::InvalidArgument("Unknown variable: ", name);
  }
  Variable &variable = formats.at(format);

  dimensions->clear();
  switch (format) {
    case VariableSpec::FORMAT_UNKNOWN:
      // This case should not happen because the |formats| mapping never has
      // FORMAT_UNKNOWN as a key.
      LOG(FATAL) << "Tried to get a variable with FORMAT_UNKNOWN";

    case VariableSpec::FORMAT_FLAT:
      *dimensions = {variable->num_columns()};
      break;

    case VariableSpec::FORMAT_ROW_MAJOR_MATRIX:
      *dimensions = {variable->num_rows(), variable->num_columns()};
      break;

    case VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
      if (override_blocked_dimensions_.find(name) !=
          override_blocked_dimensions_.end()) {
        *dimensions = override_blocked_dimensions_[name];
      } else {
        *dimensions = {variable->num_rows(), variable->num_columns(),
                       variable->num_columns()};  // = block_size
      }
      break;
  }

  *area = variable.area();
  return tensorflow::Status::OK();
}

// Executes cleanup functions (see `cleanup_` comment).
SimpleFakeVariableStore::~SimpleFakeVariableStore() {
  for (const auto &fcn : cleanup_) {
    fcn();
  }
}

tensorflow::Status SimpleFakeVariableStore::Lookup(
    const string &name, VariableSpec::Format format,
    std::vector<size_t> *dimensions, AlignedArea *area) {
  // Test should call MockLookup() first.
  CHECK(dimensions_to_return_ != nullptr);
  CHECK(area_to_return_ != nullptr);
  *dimensions = *dimensions_to_return_;
  *area = *area_to_return_;
  dimensions_to_return_ = nullptr;
  area_to_return_ = nullptr;
  return tensorflow::Status::OK();
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
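To make the flattening in AddOrDie() above concrete, the following comment-only illustration (not part of the original file) traces a hypothetical 2x2 input through the formats that Lookup() later reports:

// Hypothetical illustration of AddOrDie("w", {{1.0, 2.0}, {3.0, 4.0}}):
//   FORMAT_FLAT:                            one row [1.0, 2.0, 3.0, 4.0];
//                                           Lookup() reports dimensions {4}.
//   FORMAT_ROW_MAJOR_MATRIX:                the 2x2 data as-is; dimensions {2, 2}.
//   FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX: same layout, block size equal to the
//                                           number of columns; dimensions {2, 2, 2}
//                                           unless overridden via
//                                           SetBlockedDimensionOverride().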
research/syntaxnet/dragnn/runtime/test/fake_variable_store.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// A fake variable store with user-specified contents.
class FakeVariableStore : public VariableStore {
 public:
  // Creates an empty store.
  FakeVariableStore() = default;

  // Adds the |data| to this as a variable with the |name| and |format|.  If
  // the |format| is FORMAT_UNKNOWN, adds the data in all formats.  On error,
  // aborts the program.
  void AddOrDie(const string &name,
                const std::vector<std::vector<float>> &data,
                VariableSpec::Format format = VariableSpec::FORMAT_UNKNOWN);

  // Overrides the default behavior of assuming that there is one block along
  // the major axis of the matrix.
  void SetBlockedDimensionOverride(const string &name,
                                   const std::vector<size_t> &dimensions);

  // Implements VariableStore.
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override { return tensorflow::Status::OK(); }

 private:
  using Variable = UniqueMatrix<float>;
  using FormatMap = std::map<VariableSpec::Format, Variable>;

  // Mappings from variable name to format to contents.
  std::map<string, FormatMap> variables_;

  // Overrides blocked dimensions.
  std::map<string, std::vector<size_t>> override_blocked_dimensions_;
};

// Syntactic sugar for replicating data to SimpleFakeVariableStore::MockLookup.
template <typename T>
std::vector<std::vector<T>> ReplicateRows(std::vector<T> values, int times) {
  return std::vector<std::vector<T>>(times, values);
}

// Simpler fake variable store, where the test just sets up the next value to
// be returned.
class SimpleFakeVariableStore : public VariableStore {
 public:
  // Executes cleanup functions (see `cleanup_` comment).
  ~SimpleFakeVariableStore() override;

  // Sets values which store().Lookup() will return.
  template <typename T>
  void MockLookup(const std::vector<size_t> &dimensions,
                  const std::vector<std::vector<T>> &area_values) {
    UniqueMatrix<T> *matrix = new UniqueMatrix<T>(area_values);
    cleanup_.push_back([matrix]() { delete matrix; });
    dimensions_to_return_.reset(new std::vector<size_t>(dimensions));
    area_to_return_.reset(new AlignedArea(matrix->area()));
  }

  using VariableStore::Lookup;  // import Lookup<T>() convenience methods
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override { return tensorflow::Status::OK(); }

 private:
  std::unique_ptr<std::vector<size_t>> dimensions_to_return_ = nullptr;
  std::unique_ptr<AlignedArea> area_to_return_ = nullptr;

  // Functions which will delete memory storing mocked arrays.  We want to keep
  // the memory accessible until the end of the test.  We also can't keep an
  // array of objects to delete, since they are of different types.
  std::vector<std::function<void()>> cleanup_;
};

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
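As a quick orientation before the test file below, here is a minimal sketch of how a test body would typically drive the two stores.  It is not part of the original commit; the variable name "my_component/weights" is hypothetical, and the Lookup(name, &matrix) calls rely on the Lookup<T>() convenience overloads imported from VariableStore above.

// Minimal usage sketch (assumptions noted above), inside a test body.
FakeVariableStore store;
store.AddOrDie("my_component/weights", {{1.0, 2.0}, {3.0, 4.0}});
Matrix<float> weights;
TF_ASSERT_OK(store.Lookup("my_component/weights", &weights));  // 2x2 matrix

SimpleFakeVariableStore simple_store;
simple_store.MockLookup<float>({1, 3}, {{5.0, 6.0, 7.0}});
Matrix<float> next;
TF_ASSERT_OK(simple_store.Lookup("any_name", &next));  // the name is ignored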
research/syntaxnet/dragnn/runtime/test/fake_variable_store_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/fake_variable_store.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Returns a data matrix that has no alignment padding.  This is required for
// BlockedMatrix, which does not tolerate alignment padding.  The contents of
// the returned matrix are [0.0, 1.0, 2.0, ...] in the natural order.
std::vector<std::vector<float>> MakeBlockedData() {
  const size_t kNumRows = 18;
  const size_t kNumColumns = internal::kAlignmentBytes / sizeof(float);
  std::vector<std::vector<float>> data(kNumRows);
  float counter = 0.0;
  for (std::vector<float> &row : data) {
    row.resize(kNumColumns);
    for (float &value : row) value = counter++;
  }
  return data;
}

// Tests that Lookup*() behaves properly w.r.t. AddOrDie().
TEST(FakeVariableStoreTest, Lookup) {
  FakeVariableStore store;
  AlignedView view;
  Vector<float> vector;
  Matrix<float> matrix;
  BlockedMatrix<float> blocked_matrix;

  // Fail to look up an unknown name.
  EXPECT_THAT(store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("Unknown variable"));
  EXPECT_TRUE(view.empty());  // not modified

  // Add some data and try looking it up.
  store.AddOrDie("foo", {{1.0, 2.0, 3.0}});

  TF_EXPECT_OK(store.Lookup("foo", &vector));
  ASSERT_EQ(vector.size(), 3);
  EXPECT_EQ(vector[0], 1.0);
  EXPECT_EQ(vector[1], 2.0);
  EXPECT_EQ(vector[2], 3.0);

  TF_EXPECT_OK(store.Lookup("foo", &matrix));
  ASSERT_EQ(matrix.num_rows(), 1);
  ASSERT_EQ(matrix.num_columns(), 3);
  EXPECT_EQ(matrix.row(0)[0], 1.0);
  EXPECT_EQ(matrix.row(0)[1], 2.0);
  EXPECT_EQ(matrix.row(0)[2], 3.0);

  // Try a funny name.
  store.AddOrDie("", {{5.0, 7.0}, {11.0, 13.0}});

  TF_EXPECT_OK(store.Lookup("", &vector));
  ASSERT_EQ(vector.size(), 4);
  EXPECT_EQ(vector[0], 5.0);
  EXPECT_EQ(vector[1], 7.0);
  EXPECT_EQ(vector[2], 11.0);
  EXPECT_EQ(vector[3], 13.0);

  TF_EXPECT_OK(store.Lookup("", &matrix));
  ASSERT_EQ(matrix.num_rows(), 2);
  ASSERT_EQ(matrix.num_columns(), 2);
  EXPECT_EQ(matrix.row(0)[0], 5.0);
  EXPECT_EQ(matrix.row(0)[1], 7.0);
  EXPECT_EQ(matrix.row(1)[0], 11.0);
  EXPECT_EQ(matrix.row(1)[1], 13.0);

  // Try blocked matrices.  These must not have alignment padding.
  const auto blocked_data = MakeBlockedData();
  store.AddOrDie("blocked", blocked_data);

  TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
  ASSERT_EQ(blocked_matrix.num_rows(), blocked_data.size());
  ASSERT_EQ(blocked_matrix.num_columns(), blocked_data[0].size());
  ASSERT_EQ(blocked_matrix.block_size(), blocked_data[0].size());
  for (size_t vector = 0; vector < blocked_matrix.num_vectors(); ++vector) {
    for (size_t i = 0; i < blocked_matrix.block_size(); ++i) {
      EXPECT_EQ(blocked_matrix.vector(vector)[i],
                vector * blocked_matrix.block_size() + i);
    }
  }

  // Check that overriding dimensions is OK.  Instead of a matrix that has
  // every row as a block, every row now has two blocks, so there are half as
  // many rows and each row (number of columns) is twice as long.
  const size_t kNumColumns = internal::kAlignmentBytes / sizeof(float);
  store.SetBlockedDimensionOverride("blocked",
                                    {9, 2 * kNumColumns, kNumColumns});
  TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
  ASSERT_EQ(blocked_matrix.num_rows(), blocked_data.size() / 2);
  ASSERT_EQ(blocked_matrix.num_columns(), 2 * blocked_data[0].size());
  ASSERT_EQ(blocked_matrix.block_size(), blocked_data[0].size());
}

// Tests that the fake variable never contains variables with unknown format.
TEST(FakeVariableStoreTest, NeverContainsUnknownFormat) {
  FakeVariableStore store;
  store.AddOrDie("foo", {{0.0}});

  std::vector<size_t> dimensions;
  AlignedArea area;
  EXPECT_THAT(
      store.Lookup("foo", VariableSpec::FORMAT_UNKNOWN, &dimensions, &area),
      test::IsErrorWithSubstr("Unknown variable"));
}

// Tests that the fake variable store can create a variable that only appears
// in one format.
TEST(FakeVariableStoreTest, AddWithSpecificFormat) {
  const auto data = MakeBlockedData();
  FakeVariableStore store;
  store.AddOrDie("flat", data, VariableSpec::FORMAT_FLAT);
  store.AddOrDie("matrix", data, VariableSpec::FORMAT_ROW_MAJOR_MATRIX);
  store.AddOrDie("blocked", data,
                 VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);

  // Vector lookups should only work for "flat".
  Vector<float> vector;
  TF_ASSERT_OK(store.Lookup("flat", &vector));
  EXPECT_THAT(store.Lookup("matrix", &vector),
              test::IsErrorWithSubstr("Unknown variable"));
  EXPECT_THAT(store.Lookup("blocked", &vector),
              test::IsErrorWithSubstr("Unknown variable"));

  // Matrix lookups should only work for "matrix".
  Matrix<float> matrix;
  EXPECT_THAT(store.Lookup("flat", &matrix),
              test::IsErrorWithSubstr("Unknown variable"));
  TF_ASSERT_OK(store.Lookup("matrix", &matrix));
  EXPECT_THAT(store.Lookup("blocked", &matrix),
              test::IsErrorWithSubstr("Unknown variable"));

  // Blocked matrix lookups should only work for "blocked".
  BlockedMatrix<float> blocked_matrix;
  EXPECT_THAT(store.Lookup("flat", &blocked_matrix),
              test::IsErrorWithSubstr("Unknown variable"));
  EXPECT_THAT(store.Lookup("matrix", &blocked_matrix),
              test::IsErrorWithSubstr("Unknown variable"));
  TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
}

// Tests that Close() always succeeds.
TEST(FakeVariableStoreTest, Close) {
  FakeVariableStore store;
  TF_EXPECT_OK(store.Close());
  store.AddOrDie("foo", {{1.0, 2.0, 3.0}});
  TF_EXPECT_OK(store.Close());
  store.AddOrDie("bar", {{1.0, 2.0}, {3.0, 4.0}});
  TF_EXPECT_OK(store.Close());
}

// Tests that SimpleFakeVariableStore returns the user-specified mock values.
TEST(SimpleFakeVariableStoreTest, ReturnsMockedValues) {
  SimpleFakeVariableStore store;
  store.MockLookup<float>({1, 2}, {{1.0, 2.0}});
  Matrix<float> matrix;
  TF_ASSERT_OK(store.Lookup("name_doesnt_matter", &matrix));
  ASSERT_EQ(matrix.num_rows(), 1);
  ASSERT_EQ(matrix.num_columns(), 2);
  EXPECT_EQ(matrix.row(0)[0], 1.0);
  EXPECT_EQ(matrix.row(0)[1], 2.0);
  TF_ASSERT_OK(store.Close());
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/test/helpers.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/helpers.h"
#include <time.h>
#include <random>
#include "dragnn/runtime/math/transformations.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

UniqueView::UniqueView(size_t size) {
  array_.Reset(size);
  view_ = array_.view();
}

UniqueArea::UniqueArea(size_t num_views, size_t view_size) {
  array_.Reset(ComputeAlignedAreaSize(num_views, view_size));
  TF_CHECK_OK(area_.Reset(array_.view(), num_views, view_size));
}

void InitRandomVector(MutableVector<float> vector) {
  // clock() is updated less frequently than a cycle counter, so keep around
  // the RNG just in case we initialize some vectors in less than a clock tick.
  thread_local std::mt19937 *rng = new std::mt19937(clock());
  std::normal_distribution<float> distribution(0.0, 1.0);
  for (int i = 0; i < vector.size(); i++) {
    vector[i] = distribution(*rng);
  }
}

void InitRandomMatrix(MutableMatrix<float> matrix) {
  // See InitRandomVector comment.
  thread_local std::mt19937 *rng = new std::mt19937(clock());
  std::normal_distribution<float> distribution(0.0, 1.0);
  GenerateMatrix(
      matrix.num_rows(), matrix.num_columns(),
      [&distribution](int row, int col) { return distribution(*rng); },
      &matrix);
}

void AvxVectorFuzzTest(
    const std::function<void(AvxFloatVec *vec)> &run,
    const std::function<void(float input_value, float output_value)> &check) {
  for (int iter = 0; iter < 100; ++iter) {
    UniqueVector<float> input(kAvxWidth);
    UniqueVector<float> output(kAvxWidth);
    InitRandomVector(*input);
    InitRandomVector(*output);

    AvxFloatVec vec;
    vec.Load(input->data());
    run(&vec);
    vec.Store(output->data());

    for (int i = 0; i < kAvxWidth; ++i) {
      check((*input)[i], (*output)[i]);
    }
  }
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
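For reference, a minimal (hypothetical, not from the original commit) call to the fuzz harness above looks like the following.  With an empty |run| callback the loaded register is stored back unchanged, so the |check| callback can assert an exact round trip; a real test would apply an AvxFloatVec operation in |run| and assert the corresponding scalar relation in |check|.

// Minimal sketch, inside a test body.
AvxVectorFuzzTest(
    [](AvxFloatVec *vec) {
      // No-op: the operation under test would be applied to |vec| here.
    },
    [](float input_value, float output_value) {
      EXPECT_EQ(output_value, input_value);  // values round-trip unchanged
    });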
research/syntaxnet/dragnn/runtime/test/helpers.h
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Helpers to make it less painful to create instances of aligned values.
// Intended for testing or benchmarking; production code should use managed
// memory allocation, for example Operands.
#ifndef DRAGNN_RUNTIME_TEST_HELPERS_H_
#define DRAGNN_RUNTIME_TEST_HELPERS_H_
#include <stddef.h>
#include <algorithm>
#include <functional>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/avx_vector_array.h"
#include "dragnn/runtime/math/types.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {

// An aligned view and its uniquely-owned underlying storage.  Can be used like
// a std::unique_ptr<MutableAlignedView>.
class UniqueView {
 public:
  // Creates a view of |size| uninitialized bytes.
  explicit UniqueView(size_t size);

  // Provides std::unique_ptr-like access.
  MutableAlignedView *get() { return &view_; }
  MutableAlignedView &operator*() { return view_; }
  MutableAlignedView *operator->() { return &view_; }

 private:
  // View and its underlying storage.
  UniqueAlignedArray array_;
  MutableAlignedView view_;
};

// An aligned area and its uniquely-owned underlying storage.  Can be used like
// a std::unique_ptr<MutableAlignedArea>.
class UniqueArea {
 public:
  // Creates an area with |num_views| sub-views, each of which has |view_size|
  // uninitialized bytes.  Check-fails on error.
  UniqueArea(size_t num_views, size_t view_size);

  // Provides std::unique_ptr-like access.
  MutableAlignedArea *get() { return &area_; }
  MutableAlignedArea &operator*() { return area_; }
  MutableAlignedArea *operator->() { return &area_; }

 private:
  // Area and its underlying storage.
  UniqueAlignedArray array_;
  MutableAlignedArea area_;
};

// A vector and its uniquely-owned underlying storage.  Can be used like a
// std::unique_ptr<MutableVector<T>>.
template <class T>
class UniqueVector {
 public:
  // Creates an empty vector.
  UniqueVector() : UniqueVector(0) {}

  // Creates a vector with |dimension| uninitialized Ts.
  explicit UniqueVector(size_t dimension)
      : view_(dimension * sizeof(T)), vector_(*view_) {}

  // Creates a vector initialized to hold the |values|.
  explicit UniqueVector(const std::vector<T> &values);

  // Provides std::unique_ptr-like access.
  MutableVector<T> *get() { return &vector_; }
  MutableVector<T> &operator*() { return vector_; }
  MutableVector<T> *operator->() { return &vector_; }

  // Returns a view pointing to the same memory.
  MutableAlignedView view() { return *view_; }

 private:
  // Vector and its underlying view.
  UniqueView view_;
  MutableVector<T> vector_;
};

// A matrix and its uniquely-owned underlying storage.  Can be used like a
// std::unique_ptr<MutableMatrix<T>>.
template <class T>
class UniqueMatrix {
 public:
  // Creates an empty matrix.
  UniqueMatrix() : UniqueMatrix(0, 0) {}

  // Creates a matrix with |num_rows| x |num_columns| uninitialized Ts.
  UniqueMatrix(size_t num_rows, size_t num_columns)
      : area_(num_rows, num_columns * sizeof(T)), matrix_(*area_) {}

  // Creates a matrix initialized to hold the |values|.
  explicit UniqueMatrix(const std::vector<std::vector<T>> &values);

  // Provides std::unique_ptr-like access.
  MutableMatrix<T> *get() { return &matrix_; }
  MutableMatrix<T> &operator*() { return matrix_; }
  MutableMatrix<T> *operator->() { return &matrix_; }

  // Returns an area pointing to the same memory.
  MutableAlignedArea area() { return *area_; }

 private:
  // Matrix and its underlying area.
  UniqueArea area_;
  MutableMatrix<T> matrix_;
};

// Implementation details below.

template <class T>
UniqueVector<T>::UniqueVector(const std::vector<T> &values)
    : UniqueVector(values.size()) {
  std::copy(values.begin(), values.end(), vector_.begin());
}

template <class T>
UniqueMatrix<T>::UniqueMatrix(const std::vector<std::vector<T>> &values)
    : UniqueMatrix(values.size(), values.empty() ? 0 : values[0].size()) {
  for (size_t i = 0; i < values.size(); ++i) {
    CHECK_EQ(values[0].size(), values[i].size());
    std::copy(values[i].begin(), values[i].end(), matrix_.row(i).begin());
  }
}

// Expects that the |matrix| contains the |data|.
template <class T>
void ExpectMatrix(Matrix<T> matrix, const std::vector<std::vector<T>> &data) {
  ASSERT_EQ(matrix.num_rows(), data.size());
  if (data.empty()) return;
  ASSERT_EQ(matrix.num_columns(), data[0].size());
  for (size_t row = 0; row < data.size(); ++row) {
    for (size_t column = 0; column < data[row].size(); ++column) {
      EXPECT_EQ(matrix.row(row)[column], data[row][column]);
    }
  }
}

// Initializes a floating-point vector with random values, using a normal
// distribution centered at 0 with standard deviation 1.
void InitRandomVector(MutableVector<float> vector);
void InitRandomMatrix(MutableMatrix<float> matrix);

// Fuzz test using AVX vectors.
// If this file gets too big, move into something like math/test_helpers.h.
void AvxVectorFuzzTest(
    const std::function<void(AvxFloatVec *vec)> &run,
    const std::function<void(float input_value, float output_value)> &check);

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TEST_HELPERS_H_
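A short usage sketch of the helpers above (not part of the original commit; the values are arbitrary).  It exercises only the members declared in this header: the initializer-list constructors, unique_ptr-style access, and the random initializers.

// Minimal sketch, inside a test body.
UniqueVector<float> vector({2.0, 3.0, 5.0});            // 3 aligned floats
UniqueMatrix<float> matrix({{1.0, 2.0}, {3.0, 4.0}});   // 2x2 aligned floats
EXPECT_EQ((*vector)[2], 5.0);                           // operator* access
EXPECT_EQ(matrix->row(1)[0], 3.0);                      // operator-> access
InitRandomVector(*vector);                              // overwrite with N(0, 1) samples
InitRandomMatrix(*matrix);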
research/syntaxnet/dragnn/runtime/test/helpers_test.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/helpers.h"
#include <string>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

// Fills the |slice| with the |value|.  Slice must have .data() and .size().
template <class Slice, class T>
void Fill(Slice slice, T value) {
  for (size_t i = 0; i < slice.size(); ++i) slice.data()[i] = value;
}

// Returns the sum of all elements in the |slice|, casted to double.  Slice
// must have .data() and .size().
template <class Slice>
double Sum(Slice slice) {
  double sum = 0.0;
  for (size_t i = 0; i < slice.size(); ++i) {
    sum += static_cast<double>(slice.data()[i]);
  }
  return sum;
}

// Expects that the two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}

// Tests that each byte of a UniqueView is usable.
TEST(UniqueViewTest, Usable) {
  UniqueView view(100);
  EXPECT_EQ(view->size(), 100);
  Fill(*view, 'x');
  LOG(INFO) << "Prevents elision by optimizer: " << Sum(*view);
  EXPECT_EQ(view->data()[0], 'x');
}

// Tests that each byte of a UniqueArea is usable.
TEST(UniqueAreaTest, Usable) {
  UniqueArea area(10, 100);
  EXPECT_EQ(area->num_views(), 10);
  EXPECT_EQ(area->view_size(), 100);
  for (size_t i = 0; i < 10; ++i) {
    Fill(area->view(i), 'y');
    LOG(INFO) << "Prevents elision by optimizer: " << Sum(area->view(i));
    EXPECT_EQ(area->view(i).data()[0], 'y');
  }
}

// Tests that UniqueVector is empty by default.
TEST(UniqueVectorTest, EmptyByDefault) {
  UniqueVector<float> vector;
  EXPECT_EQ(vector->size(), 0);
}

// Tests that each element of a UniqueVector is usable.
TEST(UniqueVectorTest, Usable) {
  UniqueVector<float> vector(100);
  EXPECT_EQ(vector->size(), 100);
  Fill(*vector, 1.5);
  LOG(INFO) << "Prevents elision by optimizer: " << Sum(*vector);
  EXPECT_EQ((*vector)[0], 1.5);
}

// Tests that UniqueVector also exports a view.
TEST(UniqueVectorTest, View) {
  UniqueVector<float> vector(123);
  ExpectSameAddress(vector.view().data(), vector->data());
  EXPECT_EQ(vector.view().size(), 123 * sizeof(float));
}

// Tests that a UniqueVector can be constructed with an initial value.
TEST(UniqueVectorTest, Initialization) {
  UniqueVector<int> vector({2, 3, 5, 7});
  EXPECT_EQ(vector->size(), 4);
  EXPECT_EQ((*vector)[0], 2);
  EXPECT_EQ((*vector)[1], 3);
  EXPECT_EQ((*vector)[2], 5);
  EXPECT_EQ((*vector)[3], 7);
}

// Tests that UniqueMatrix is empty by default.
TEST(UniqueMatrixTest, EmptyByDefault) {
  UniqueMatrix<float> row_major_matrix;
  EXPECT_EQ(row_major_matrix->num_rows(), 0);
  EXPECT_EQ(row_major_matrix->num_columns(), 0);
}

// Tests that each element of a UniqueMatrix is usable.
TEST(UniqueMatrixTest, Usable) {
  UniqueMatrix<float> row_major_matrix(10, 100);
  EXPECT_EQ(row_major_matrix->num_rows(), 10);
  EXPECT_EQ(row_major_matrix->num_columns(), 100);
  for (size_t i = 0; i < 10; ++i) {
    Fill(row_major_matrix->row(i), 1.75);
    LOG(INFO) << "Prevents elision by optimizer: "
              << Sum(row_major_matrix->row(i));
    EXPECT_EQ(row_major_matrix->row(i)[0], 1.75);
  }
}

// Tests that UniqueMatrix also exports an area.
TEST(UniqueMatrixTest, Area) {
  UniqueMatrix<float> row_major_matrix(12, 34);
  ExpectSameAddress(row_major_matrix.area().view(0).data(),
                    row_major_matrix->row(0).data());
  EXPECT_EQ(row_major_matrix.area().num_views(), 12);
  EXPECT_EQ(row_major_matrix.area().view_size(), 34 * sizeof(float));
}

// Tests that a UniqueMatrix can be constructed with an initial value.
TEST(UniqueMatrixTest, Initialization) {
  UniqueMatrix<int> row_major_matrix({{2, 3, 5}, {7, 11, 13}});
  EXPECT_EQ(row_major_matrix->num_rows(), 2);
  EXPECT_EQ(row_major_matrix->num_columns(), 3);
  EXPECT_EQ(row_major_matrix->row(0)[0], 2);
  EXPECT_EQ(row_major_matrix->row(0)[1], 3);
  EXPECT_EQ(row_major_matrix->row(0)[2], 5);
  EXPECT_EQ(row_major_matrix->row(1)[0], 7);
  EXPECT_EQ(row_major_matrix->row(1)[1], 11);
  EXPECT_EQ(row_major_matrix->row(1)[2], 13);
}

}  // namespace
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
research/syntaxnet/dragnn/runtime/test/network_test_base.cc
deleted
100644 → 0
View file @
a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {

using ::testing::InSequence;
using ::testing::Return;

// Fills the |matrix| with the |fill_value|.
void Fill(float fill_value, MutableMatrix<float> matrix) {
  for (size_t i = 0; i < matrix.num_rows(); ++i) {
    for (float &value : matrix.row(i)) value = fill_value;
  }
}

}  // namespace

constexpr char NetworkTestBase::kTestComponentName[];

void NetworkTestBase::TearDown() {
  // The state extensions may contain objects that cannot outlive the
  // component, so discard the extensions early.  This is not an issue in
  // real-world usage, as the Master calls destructors in the right order.
  session_state_.extensions = Extensions();
}

NetworkTestBase::GetInputFeaturesFunctor NetworkTestBase::ExtractFeatures(
    int expected_channel_id, const std::vector<Feature> &features) {
  return [=](const string &component_name,
             std::function<int32 *(int)> allocate_indices,
             std::function<int64 *(int)> allocate_ids,
             std::function<float *(int)> allocate_weights, int channel_id) {
    EXPECT_EQ(component_name, kTestComponentName);
    EXPECT_EQ(channel_id, expected_channel_id);

    const int num_features = features.size();
    int32 *indices = allocate_indices(num_features);
    int64 *ids = allocate_ids(num_features);
    float *weights = allocate_weights(num_features);
    for (int i = 0; i < num_features; ++i) {
      indices[i] = features[i].index;
      ids[i] = features[i].id;
      weights[i] = features[i].weight;
    }
    return num_features;
  };
}

NetworkTestBase::GetTranslatedLinkFeaturesFunctor NetworkTestBase::ExtractLinks(
    int expected_channel_id, const std::vector<string> &features_text) {
  std::vector<LinkFeatures> features;
  for (const string &text : features_text) {
    features.emplace_back();
    CHECK(TextFormat::ParseFromString(text, &features.back()));
  }

  return [=](const string &component_name, int channel_id) {
    EXPECT_EQ(component_name, kTestComponentName);
    EXPECT_EQ(channel_id, expected_channel_id);
    return features;
  };
}

void NetworkTestBase::AddVectorVariable(const string &name, size_t dimension,
                                        float fill_value) {
  const std::vector<float> row(dimension, fill_value);
  const std::vector<std::vector<float>> values(1, row);
  variable_store_.AddOrDie(name, values);
}

void NetworkTestBase::AddMatrixVariable(const string &name, size_t num_rows,
                                        size_t num_columns, float fill_value) {
  const std::vector<float> row(num_columns, fill_value);
  const std::vector<std::vector<float>> values(num_rows, row);
  variable_store_.AddOrDie(name, values);
}

void NetworkTestBase::AddFixedEmbeddingMatrix(int channel_id,
                                              size_t vocabulary_size,
                                              size_t embedding_dim,
                                              float fill_value) {
  const string name = tensorflow::strings::StrCat(
      kTestComponentName, "/fixed_embedding_matrix_", channel_id, "/trimmed");
  AddMatrixVariable(name, vocabulary_size, embedding_dim, fill_value);
}

void NetworkTestBase::AddLinkedWeightMatrix(int channel_id, size_t source_dim,
                                            size_t embedding_dim,
                                            float fill_value) {
  const string name = tensorflow::strings::StrCat(
      kTestComponentName, "/linked_embedding_matrix_", channel_id, "/weights",
      FlexibleMatrixKernel::kSuffix);
  AddMatrixVariable(name, embedding_dim, source_dim, fill_value);
}

void NetworkTestBase::AddLinkedOutOfBoundsVector(int channel_id,
                                                 size_t embedding_dim,
                                                 float fill_value) {
  const string name = tensorflow::strings::StrCat(
      kTestComponentName, "/linked_embedding_matrix_", channel_id,
      "/out_of_bounds");
  AddVectorVariable(name, embedding_dim, fill_value);
}

void NetworkTestBase::AddComponent(const string &component_name) {
  TF_ASSERT_OK(network_state_manager_.AddComponent(component_name));
}

void NetworkTestBase::AddLayer(const string &layer_name, size_t dimension) {
  LayerHandle<float> unused_layer_handle;
  TF_ASSERT_OK(network_state_manager_.AddLayer(layer_name, dimension,
                                               &unused_layer_handle));
}

void NetworkTestBase::AddPairwiseLayer(const string &layer_name,
                                       size_t dimension) {
  PairwiseLayerHandle<float> unused_layer_handle;
  TF_ASSERT_OK(network_state_manager_.AddLayer(layer_name, dimension,
                                               &unused_layer_handle));
}

void NetworkTestBase::StartComponent(size_t num_steps) {
  // The pre-allocation hint is arbitrary, but setting it to a small value
  // exercises reallocations.
  TF_ASSERT_OK(network_states_.StartNextComponent(5));
  for (size_t i = 0; i < num_steps; ++i) network_states_.AddStep();
}

MutableMatrix<float> NetworkTestBase::GetLayer(const string &component_name,
                                               const string &layer_name) const {
  size_t unused_dimension = 0;
  LayerHandle<float> handle;
  TF_CHECK_OK(network_state_manager_.LookupLayer(component_name, layer_name,
                                                 &unused_dimension, &handle));
  return network_states_.GetLayer(handle);
}

MutableMatrix<float> NetworkTestBase::GetPairwiseLayer(
    const string &component_name, const string &layer_name) const {
  size_t unused_dimension = 0;
  PairwiseLayerHandle<float> handle;
  TF_CHECK_OK(network_state_manager_.LookupLayer(component_name, layer_name,
                                                 &unused_dimension, &handle));
  return network_states_.GetLayer(handle);
}

void NetworkTestBase::FillLayer(const string &component_name,
                                const string &layer_name,
                                float fill_value) const {
  Fill(fill_value, GetLayer(component_name, layer_name));
}

void NetworkTestBase::SetupTransitionLoop(size_t num_steps) {
  // Return not terminal |num_steps| times, then return terminal.
  InSequence scoped;
  EXPECT_CALL(compute_session_, IsTerminal(kTestComponentName))
      .Times(num_steps)
      .WillRepeatedly(Return(false))
      .RetiresOnSaturation();
  EXPECT_CALL(compute_session_, IsTerminal(kTestComponentName))
      .WillOnce(Return(true));
}

void NetworkTestBase::ExpectVector(Vector<float> vector, size_t dimension,
                                   float expected_value) {
  ASSERT_EQ(vector.size(), dimension);
  for (const float value : vector) EXPECT_EQ(value, expected_value);
}

void NetworkTestBase::ExpectMatrix(Matrix<float> matrix, size_t num_rows,
                                   size_t num_columns, float expected_value) {
  ASSERT_EQ(matrix.num_rows(), num_rows);
  ASSERT_EQ(matrix.num_columns(), num_columns);
  for (size_t row = 0; row < num_rows; ++row) {
    ExpectVector(matrix.row(row), num_columns, expected_value);
  }
}

}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet
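A rough sketch of how a component test typically drives this fixture (not part of the original commit).  The fixture name, the "hidden" layer, and all dimensions are illustrative; the exact member declarations and any required setup live in network_test_base.h, which is not shown in this diff, and the final check assumes the usual implicit conversion from MutableMatrix to Matrix.

// Hypothetical sketch under the assumptions above.
class MyComponentTest : public NetworkTestBase {};

TEST_F(MyComponentTest, FillsAndChecksALayer) {
  AddComponent(kTestComponentName);
  AddLayer("hidden", /*dimension=*/8);
  StartComponent(/*num_steps=*/3);            // 3 steps of an 8-dim layer
  FillLayer(kTestComponentName, "hidden", 1.0);
  ExpectMatrix(GetLayer(kTestComponentName, "hidden"),
               /*num_rows=*/3, /*num_columns=*/8, /*expected_value=*/1.0);
}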