ModelZoo / ResNet50_tensorflow · Commits

Unverified commit 80178fc6, authored May 11, 2018 by Mark Omernick, committed via GitHub on May 11, 2018.

    Merge pull request #4153 from terryykoo/master

    Export @195097388.

Parents: a84e1ef9, edea2b67

Showing 20 changed files with 604 additions and 196 deletions (+604, −196):
    research/syntaxnet/syntaxnet/feature_extractor.h            +1   −0
    research/syntaxnet/syntaxnet/fml_parser.h                    +4   −0
    research/syntaxnet/syntaxnet/fml_parser_test.cc              +77  −0
    research/syntaxnet/syntaxnet/graph_builder.py                +3   −1
    research/syntaxnet/syntaxnet/graph_builder_test.py           +7   −14
    research/syntaxnet/syntaxnet/lexicon_builder_test.py         +9   −15
    research/syntaxnet/syntaxnet/ops/parser_ops.cc               +98  −0
    research/syntaxnet/syntaxnet/ops/shape_helpers.h             +74  −0
    research/syntaxnet/syntaxnet/parser_eval.py                  +3   −2
    research/syntaxnet/syntaxnet/parser_features.cc              +0   −18
    research/syntaxnet/syntaxnet/parser_trainer.py               +3   −2
    research/syntaxnet/syntaxnet/reader_ops.cc                   +14  −1
    research/syntaxnet/syntaxnet/reader_ops_test.py              +55  −21
    research/syntaxnet/syntaxnet/registry.cc                     +37  −0
    research/syntaxnet/syntaxnet/registry.h                      +46  −7
    research/syntaxnet/syntaxnet/registry_test.cc                +95  −0
    research/syntaxnet/syntaxnet/shared_store.h                  +1   −1
    research/syntaxnet/syntaxnet/structured_graph_builder.py     +2   −3
    research/syntaxnet/syntaxnet/syntaxnet.bzl                   +21  −96
    research/syntaxnet/syntaxnet/term_frequency_map.cc           +54  −15
research/syntaxnet/syntaxnet/feature_extractor.h (+1, −0):

@@ -250,6 +250,7 @@ class GenericFeatureFunction {
   string GetParameter(const string &name) const;
   int GetIntParameter(const string &name, int default_value) const;
   bool GetBoolParameter(const string &name, bool default_value) const;
+  double GetFloatParameter(const string &name, double default_value) const;

   // Returns the FML function description for the feature function, i.e. the
   // name and parameters without the nested features.
research/syntaxnet/syntaxnet/fml_parser.h (+4, −0):

@@ -108,6 +108,10 @@ class FMLParser {
   string item_text_;
 };

+// Returns the |function| or |extractor| descriptor as an FML string.
+string AsFML(const FeatureFunctionDescriptor &function);
+string AsFML(const FeatureExtractorDescriptor &extractor);
+
 }  // namespace syntaxnet

 #endif  // SYNTAXNET_FML_PARSER_H_
research/syntaxnet/syntaxnet/fml_parser_test.cc (new file, mode 100644, +77):

/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "syntaxnet/fml_parser.h"

#include <string>
#include <vector>

#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace syntaxnet {
namespace {

// Returns the list of lines in the |text|. Also strips trailing whitespace
// from each line, since the FML generator sometimes appends trailing spaces.
std::vector<string> LinesOf(const string &text) {
  std::vector<string> lines = tensorflow::str_util::Split(
      text, "\n", tensorflow::str_util::SkipEmpty());
  for (string &line : lines) {
    tensorflow::str_util::StripTrailingWhitespace(&line);
  }
  return lines;
}

// Tests that a single function can be round-trip converted from FML to
// descriptor protos and back to FML.
TEST(FMLParserTest, RoundTripSingleFunction) {
  FeatureExtractorDescriptor extractor;
  FMLParser().Parse("offset(1).input.token.word(min-freq=10)", &extractor);
  EXPECT_EQ(LinesOf(AsFML(extractor)),
            LinesOf("offset(1).input.token.word(min-freq=\"10\")"));

  // Also check each individual feature function.
  EXPECT_EQ(AsFML(extractor.feature(0)),
            "offset(1).input.token.word(min-freq=\"10\")");
  EXPECT_EQ(AsFML(extractor.feature(0).feature(0)),
            "input.token.word(min-freq=\"10\")");
  EXPECT_EQ(AsFML(extractor.feature(0).feature(0).feature(0)),
            "token.word(min-freq=\"10\")");
  EXPECT_EQ(AsFML(extractor.feature(0).feature(0).feature(0).feature(0)),
            "word(min-freq=\"10\")");
}

// Tests that a set of functions can be round-trip converted from FML to
// descriptor protos and back to FML.
TEST(FMLParserTest, RoundTripMultipleFunctions) {
  FeatureExtractorDescriptor extractor;
  FMLParser().Parse(
      R"(offset(1).word(max-num-terms=987)
         input { tag(outside=false) label }
         pairs { stack.tag input.tag input.child(-1).label })",
      &extractor);

  // Note that AsFML() adds quotes to all feature option values.
  EXPECT_EQ(LinesOf(AsFML(extractor)),
            LinesOf("offset(1).word(max-num-terms=\"987\")\n"
                    "input { tag(outside=\"false\") label }\n"
                    "pairs { stack.tag input.tag input.child(-1).label }"));
}

}  // namespace
}  // namespace syntaxnet
research/syntaxnet/syntaxnet/graph_builder.py (+3, −1):

@@ -22,6 +22,7 @@ import syntaxnet.load_parser_ops
 from tensorflow.python.ops import control_flow_ops as cf
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saver as tf_saver

 from syntaxnet.ops import gen_parser_ops

@@ -572,5 +573,6 @@ class GreedyParser(object):
     for key in variables_to_save.keys():
       if not key.endswith('avg_var'):
         del variables_to_save[key]
-    self.saver = tf.train.Saver(variables_to_save)
+    self.saver = tf.train.Saver(
+        variables_to_save, builder=tf_saver.BaseSaverBuilder())
     return self.saver
research/syntaxnet/syntaxnet/graph_builder_test.py (+7, −14):

@@ -20,33 +20,26 @@
 import os.path

 import tensorflow as tf

-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import googletest

 from syntaxnet import graph_builder
 from syntaxnet import sparse_pb2
+from syntaxnet import test_flags
 from syntaxnet.ops import gen_parser_ops

-FLAGS = tf.app.flags.FLAGS
-if not hasattr(FLAGS, 'test_srcdir'):
-  FLAGS.test_srcdir = ''
-if not hasattr(FLAGS, 'test_tmpdir'):
-  FLAGS.test_tmpdir = tf.test.get_temp_dir()

-class GraphBuilderTest(test_util.TensorFlowTestCase):
+class GraphBuilderTest(tf.test.TestCase):

   def setUp(self):
     # Creates a task context with the correct testing paths.
-    initial_task_context = os.path.join(FLAGS.test_srcdir,
-                                        'syntaxnet/'
-                                        'testdata/context.pbtxt')
-    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    initial_task_context = os.path.join(test_flags.source_root(),
+                                        'syntaxnet/'
+                                        'testdata/context.pbtxt')
+    self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
     with open(initial_task_context, 'r') as fin:
       with open(self._task_context, 'w') as fout:
-        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
-                   .replace('OUTPATH', FLAGS.test_tmpdir))
+        fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
+                   .replace('OUTPATH', test_flags.temp_dir()))

     # Creates necessary term maps.
     with self.test_session() as sess:

@@ -320,4 +313,4 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
 if __name__ == '__main__':
-  googletest.main()
+  tf.test.main()
research/syntaxnet/syntaxnet/lexicon_builder_test.py (+9, −15):

@@ -23,16 +23,13 @@ import tensorflow as tf
 import syntaxnet.load_parser_ops

-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging

 from syntaxnet import sentence_pb2
 from syntaxnet import task_spec_pb2
+from syntaxnet import test_flags
 from syntaxnet.ops import gen_parser_ops

-FLAGS = tf.app.flags.FLAGS

 CONLL_DOC1 = u'''1 बात _ n NN _ _ _ _ _
 2 गलत _ adj JJ _ _ _ _ _
 3 हो _ v VM _ _ _ _ _

@@ -75,15 +72,11 @@ CHAR_NGRAMS = u'''^ अ ^ अभ ^ आ ^ आन ^ इ ^ इस $ ^ क ^
 COMMENTS = u'# Line with fake comments.'

-class LexiconBuilderTest(test_util.TensorFlowTestCase):
+class LexiconBuilderTest(tf.test.TestCase):

   def setUp(self):
-    if not hasattr(FLAGS, 'test_srcdir'):
-      FLAGS.test_srcdir = ''
-    if not hasattr(FLAGS, 'test_tmpdir'):
-      FLAGS.test_tmpdir = tf.test.get_temp_dir()
-    self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
-    self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    self.corpus_file = os.path.join(test_flags.temp_dir(), 'documents.conll')
+    self.context_file = os.path.join(test_flags.temp_dir(), 'context.pbtxt')

   def AddInput(self, name, file_pattern, record_format, context):
     inp = context.input.add()

@@ -106,7 +99,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
                  'category-map', 'label-map', 'prefix-table',
                  'suffix-table', 'tag-to-category', 'char-map',
                  'char-ngram-map'):
-      self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
+      self.AddInput(name, os.path.join(test_flags.temp_dir(), name), '',
+                    context)
     logging.info('Writing context to: %s', self.context_file)
     with open(self.context_file, 'w') as f:
       f.write(str(context))

@@ -140,7 +134,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
     self.assertTrue(last)

   def ValidateTagToCategoryMap(self):
-    with open(os.path.join(FLAGS.test_tmpdir, 'tag-to-category'), 'r') as f:
+    with open(os.path.join(test_flags.temp_dir(), 'tag-to-category'), 'r') as f:
       entries = [line.strip().split('\t') for line in f.readlines()]
     for tag, category in entries:
       self.assertIn(tag, TAGS)

@@ -148,7 +142,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
   def LoadMap(self, map_name):
     loaded_map = {}
-    with open(os.path.join(FLAGS.test_tmpdir, map_name), 'r') as f:
+    with open(os.path.join(test_flags.temp_dir(), map_name), 'r') as f:
       for line in f:
         entries = line.strip().split(' ')
         if len(entries) >= 2:

@@ -237,4 +231,4 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
 if __name__ == '__main__':
-  googletest.main()
+  tf.test.main()
research/syntaxnet/syntaxnet/ops/parser_ops.cc (+98, −0):

@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "syntaxnet/ops/shape_helpers.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"

 namespace syntaxnet {

@@ -29,6 +31,14 @@ REGISTER_OP("GoldParseReader")
     .Attr("corpus_name: string='documents'")
     .Attr("arg_prefix: string='brain_parser'")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int feature_size;
+      TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
+      for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
+      ScalarOutputShape(feature_size, context);
+      VectorOutputShape(feature_size + 1, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 Reads sentences, parses them, and returns (gold action, feature) pairs.

@@ -55,6 +65,15 @@ REGISTER_OP("DecodedParseReader")
     .Attr("corpus_name: string='documents'")
     .Attr("arg_prefix: string='brain_parser'")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int feature_size;
+      TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
+      for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
+      ScalarOutputShape(feature_size, context);
+      context->set_output(feature_size + 1, context->Vector(2));
+      VectorOutputShape(feature_size + 2, context);
+      return MatrixInputShape(0, context);
+    })
     .Doc(R"doc(
 Reads sentences and parses them taking parsing transitions based on the
 input transition scores.

@@ -85,6 +104,14 @@ REGISTER_OP("BeamParseReader")
     .Attr("continue_until_all_final: bool=false")
     .Attr("always_start_new_sentences: bool=false")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int feature_size;
+      TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
+      for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
+      ScalarOutputShape(feature_size, context);
+      ScalarOutputShape(feature_size + 1, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 Reads sentences and creates a beam parser.

@@ -112,6 +139,15 @@ REGISTER_OP("BeamParser")
     .Output("alive: bool")
     .Attr("feature_size: int")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      int feature_size;
+      TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
+      for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
+      ScalarOutputShape(feature_size, context);
+      VectorOutputShape(feature_size + 1, context);
+      TF_RETURN_IF_ERROR(ScalarInputShape(0, context));
+      return MatrixInputShape(1, context);
+    })
     .Doc(R"doc(
 Updates the beam parser based on scores in the input transition scores.

@@ -131,6 +167,13 @@ REGISTER_OP("BeamParserOutput")
     .Output("gold_slot: int32")
     .Output("path_scores: float")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      context->set_output(0, context->Matrix(2, context->UnknownDim()));
+      context->set_output(1, context->Matrix(2, context->UnknownDim()));
+      VectorOutputShape(2, context);
+      VectorOutputShape(3, context);
+      return ScalarInputShape(0, context);
+    })
     .Doc(R"doc(
 Converts the current state of the beam parser into a set of indices into
 the scoring matrices that lead there.

@@ -152,6 +195,11 @@ REGISTER_OP("BeamEvalOutput")
     .Output("eval_metrics: int32")
     .Output("documents: string")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      context->set_output(0, context->Vector(2));
+      VectorOutputShape(1, context);
+      return ScalarInputShape(0, context);
+    })
     .Doc(R"doc(
 Computes eval metrics for the best paths in the input beams.

@@ -192,6 +240,13 @@ REGISTER_OP("FeatureSize")
     .Output("embedding_dims: int32")
     .Output("num_actions: int32")
     .Attr("arg_prefix: string='brain_parser'")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      VectorOutputShape(1, context);
+      VectorOutputShape(2, context);
+      ScalarOutputShape(3, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 An op that returns the number and domain sizes of parser features.

@@ -210,6 +265,10 @@ REGISTER_OP("FeatureVocab")
     .Attr("arg_prefix: string='brain_parser'")
     .Attr("embedding_name: string='words'")
     .Output("vocab: string")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 Returns the vocabulary of the parser features for a particular named channel.
 For "words" this would be the entire vocabulary, plus any special tokens

@@ -227,6 +286,12 @@ REGISTER_OP("UnpackSyntaxNetSparseFeatures")
     .Output("indices: int32")
     .Output("ids: int64")
     .Output("weights: float")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      VectorOutputShape(1, context);
+      VectorOutputShape(2, context);
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Converts a vector of strings with SparseFeatures to tensors.

@@ -249,11 +314,16 @@ REGISTER_OP("WordEmbeddingInitializer")
     .Attr("vectors: string")
     .Attr("task_context: string = ''")
     .Attr("vocabulary: string = ''")
+    .Attr("override_num_embeddings: int = -1")
     .Attr("cache_vectors_locally: bool = true")
     .Attr("num_special_embeddings: int = 3")
     .Attr("embedding_init: float = 1.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      MatrixOutputShape(0, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
 every word specified in a text vocabulary file.

@@ -264,6 +334,10 @@ task_context: file path at which to read the task context, for its "word-map"
   input. Exactly one of `task_context` or `vocabulary` must be specified.
 vocabulary: path to vocabulary file, which contains one unique word per line, in
   order. Exactly one of `task_context` or `vocabulary` must be specified.
+override_num_embeddings: Number of rows in the returned embedding matrix. If
+  override_num_embeddings is larger than 0, then the returned embedding matrix
+  has override_num_embeddings rows. Otherwise, the number of rows of the
+  returned embedding matrix is |vocabulary| + num_special_embeddings.
 cache_vectors_locally: Whether to cache the vectors file to a local temp file
   before parsing it. This greatly reduces initialization time when the vectors
   are stored remotely, but requires that "/tmp" has sufficient space.

@@ -286,6 +360,11 @@ REGISTER_OP("DocumentSource")
     .Attr("corpus_name: string='documents'")
     .Attr("batch_size: int")
     .SetIsStateful()
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      ScalarOutputShape(1, context);
+      return tensorflow::Status::OK();
+    })
     .Doc(R"doc(
 Reads documents from documents_path and outputs them.

@@ -301,6 +380,9 @@ REGISTER_OP("DocumentSink")
     .Attr("task_context: string=''")
     .Attr("task_context_str: string=''")
     .Attr("corpus_name: string='documents'")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Write documents to documents_path.

@@ -312,6 +394,10 @@ task_context_str: a task context in text format, used if task_context is empty.
 REGISTER_OP("SegmenterTrainingDataConstructor")
     .Input("documents: string")
     .Output("char_doc: string")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Constructs segmentation training data from documents with gold segmentation.

@@ -322,6 +408,10 @@ char_doc: a vector of documents as serialized protos.
 REGISTER_OP("CharTokenGenerator")
     .Input("documents: string")
     .Output("char_doc: string")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Converts token field of the input documents such that each token in the
 output doc is a utf-8 character from that doc's text.

@@ -337,6 +427,10 @@ REGISTER_OP("WellFormedFilter")
     .Attr("task_context_str: string=''")
     .Attr("corpus_name: string='documents'")
     .Attr("keep_malformed_documents: bool = False")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Removes sentences with malformed parse trees, i.e. they contain cycles.

@@ -353,6 +447,10 @@ REGISTER_OP("ProjectivizeFilter")
     .Attr("task_context_str: string=''")
     .Attr("corpus_name: string='documents'")
     .Attr("discard_non_projective: bool = False")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      VectorOutputShape(0, context);
+      return VectorInputShape(0, context);
+    })
     .Doc(R"doc(
 Modifies input parse trees to make them projective.
research/syntaxnet/syntaxnet/ops/shape_helpers.h (new file, mode 100644, +74):

// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

// Shape inference functions for SyntaxNet ops.

#ifndef SYNTAXNET_OPS_SHAPE_HELPERS_H_
#define SYNTAXNET_OPS_SHAPE_HELPERS_H_

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace syntaxnet {

// Returns OK if the |input_index|'th input is a tensor of the |rank| with
// unknown dimensions.
inline tensorflow::Status TensorInputShape(
    int input_index, int rank,
    tensorflow::shape_inference::InferenceContext *context) {
  tensorflow::shape_inference::ShapeHandle unused;
  return context->WithRank(context->input(input_index), rank, &unused);
}

// Returns OK if the |input_index|'th input is a scalar.
inline tensorflow::Status ScalarInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, 0, context);
}

// Returns OK if the |input_index|'th input is a vector of unknown dimension.
inline tensorflow::Status VectorInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, 1, context);
}

// Returns OK if the |input_index|'th input is a matrix of unknown dimensions.
inline tensorflow::Status MatrixInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, 2, context);
}

// Sets the |output_index|'th output to a scalar.
inline void ScalarOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  context->set_output(output_index, context->Scalar());
}

// Sets the |output_index|'th output to a vector of unknown dimension.
inline void VectorOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  context->set_output(output_index, context->UnknownShapeOfRank(1));
}

// Sets the |output_index|'th output to a matrix of unknown dimensions.
inline void MatrixOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  context->set_output(output_index, context->UnknownShapeOfRank(2));
}

}  // namespace syntaxnet

#endif  // SYNTAXNET_OPS_SHAPE_HELPERS_H_
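Taken together, these helpers let each shape function in parser_ops.cc collapse to a few declarative lines. As a minimal sketch of the pattern (the op name "LineCounter" and its signature are invented for illustration and are not part of this commit; the helper calls and the SetShapeFn idiom are the ones introduced above):

    // Hypothetical op registration, shown only to illustrate the helpers.
    #include "syntaxnet/ops/shape_helpers.h"
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"

    namespace syntaxnet {

    REGISTER_OP("LineCounter")  // invented name, for illustration only
        .Input("documents: string")
        .Output("counts: int32")
        .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
          // Declare output 0 as a vector of unknown length, and require that
          // input 0 is itself a vector; a rank mismatch is then reported at
          // graph construction time rather than when the kernel runs.
          VectorOutputShape(0, context);
          return VectorInputShape(0, context);
        });

    }  // namespace syntaxnet

This is the same shape contract the commit gives to CharTokenGenerator and the other vector-in, vector-out document ops above.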
research/syntaxnet/syntaxnet/parser_eval.py (+3, −2):

@@ -19,6 +19,8 @@
 import os
 import os.path
 import time
+from absl import app
+from absl import flags
 import tempfile
 import tensorflow as tf

@@ -33,7 +35,6 @@ from syntaxnet import structured_graph_builder
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet import task_spec_pb2

-flags = tf.app.flags
 FLAGS = flags.FLAGS

@@ -158,4 +159,4 @@ def main(unused_argv):
 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
research/syntaxnet/syntaxnet/parser_features.cc (+0, −18):

@@ -331,24 +331,6 @@ class LastActionFeatureFunction : public ParserFeatureFunction {
 REGISTER_PARSER_FEATURE_FUNCTION("last-action", LastActionFeatureFunction);

-class Constant : public ParserFeatureFunction {
- public:
-  void Init(TaskContext *context) override {
-    value_ = this->GetIntParameter("value", 0);
-    this->set_feature_type(new NumericFeatureType(this->name(), value_ + 1));
-  }
-
-  // Returns the constant's value.
-  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
-                       const FeatureVector *result) const override {
-    return value_;
-  }
-
- private:
-  int value_ = 0;
-};
-REGISTER_PARSER_FEATURE_FUNCTION("constant", Constant);
-
 // Register the generic parser features.
 typedef GenericFeatures<ParserState> GenericParserFeature;
 REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature);
research/syntaxnet/syntaxnet/parser_trainer.py (+3, −2):

@@ -20,6 +20,8 @@
 import os
 import os.path
 import time
+from absl import app
+from absl import flags
 import tensorflow as tf

 from tensorflow.python.platform import gfile

@@ -32,7 +34,6 @@ from syntaxnet import structured_graph_builder
 from syntaxnet.ops import gen_parser_ops
 from syntaxnet import task_spec_pb2

-flags = tf.app.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string('tf_master', '',

@@ -299,4 +300,4 @@ def main(unused_argv):
 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
research/syntaxnet/syntaxnet/reader_ops.cc (+14, −1):

@@ -453,6 +453,8 @@ class WordEmbeddingInitializer : public OpKernel {
                                              &cache_vectors_locally_));
     OP_REQUIRES_OK(context, context->GetAttr("num_special_embeddings",
                                              &num_special_embeddings_));
+    OP_REQUIRES_OK(context, context->GetAttr("override_num_embeddings",
+                                             &override_num_embeddings_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("embedding_init", &embedding_init_));

@@ -569,7 +571,13 @@ class WordEmbeddingInitializer : public OpKernel {
       const std::unordered_map<string, int64> &vocabulary,
       const TokenEmbedding &embedding, OpKernelContext *context,
       Tensor **embedding_matrix) const {
-    const int rows = vocabulary.size() + num_special_embeddings_;
+    const int rows = override_num_embeddings_ > 0
+                         ? override_num_embeddings_
+                         : (vocabulary.size() + num_special_embeddings_);
+    if (rows < vocabulary.size()) {
+      return InvalidArgument("Embedding matrix row number ", rows,
+                             " is less than vocabulary size ", vocabulary.size());
+    }
     const int columns = embedding.vector().values_size();
     TF_RETURN_IF_ERROR(context->allocate_output(
         0, TensorShape({rows, columns}), embedding_matrix));

@@ -637,6 +645,11 @@ class WordEmbeddingInitializer : public OpKernel {
   // Number of special embeddings to allocate.
   int num_special_embeddings_ = 3;

+  // If override_num_embeddings_ is larger than zero, then the returned
+  // embedding matrix has override_num_embeddings_ rows. Otherwise, the
+  // number of rows equals |vocabulary| + num_special_embeddings_.
+  int override_num_embeddings_ = -1;
+
   // Seed for random initialization.
   uint64 seed_ = 0;
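To make the new row rule concrete, here is a standalone sketch of the arithmetic (illustrative code, not from the commit): with a 4-word vocabulary and the default of 3 special embeddings, the matrix gets 4 + 3 = 7 rows; an override of 10 forces exactly 10 rows; and an override smaller than the vocabulary size is rejected by the new InvalidArgument check.

    // Standalone illustration of the row-count rule added to
    // WordEmbeddingInitializer above; not commit code.
    #include <cstdio>

    int NumEmbeddingRows(int vocabulary_size, int num_special_embeddings,
                         int override_num_embeddings) {
      // Mirrors the ternary introduced above: a positive override wins,
      // otherwise fall back to vocabulary size plus special rows.
      return override_num_embeddings > 0
                 ? override_num_embeddings
                 : vocabulary_size + num_special_embeddings;
    }

    int main() {
      std::printf("%d\n", NumEmbeddingRows(4, 3, -1));  // 7: no override
      std::printf("%d\n", NumEmbeddingRows(4, 3, 10));  // 10: override wins
      // NumEmbeddingRows(4, 3, 2) would yield 2 (< vocabulary size), which
      // the kernel now rejects with InvalidArgument.
      return 0;
    }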
research/syntaxnet/syntaxnet/reader_ops_test.py (+55, −21):

@@ -20,35 +20,27 @@ import os.path
 import numpy as np
 import tensorflow as tf

-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging

 from syntaxnet import dictionary_pb2
 from syntaxnet import graph_builder
 from syntaxnet import sparse_pb2
+from syntaxnet import test_flags
 from syntaxnet.ops import gen_parser_ops

-FLAGS = tf.app.flags.FLAGS
-if not hasattr(FLAGS, 'test_srcdir'):
-  FLAGS.test_srcdir = ''
-if not hasattr(FLAGS, 'test_tmpdir'):
-  FLAGS.test_tmpdir = tf.test.get_temp_dir()

-class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
+class ParsingReaderOpsTest(tf.test.TestCase):

   def setUp(self):
     # Creates a task context with the correct testing paths.
-    initial_task_context = os.path.join(FLAGS.test_srcdir,
-                                        'syntaxnet/'
-                                        'testdata/context.pbtxt')
-    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    initial_task_context = os.path.join(test_flags.source_root(),
+                                        'syntaxnet/'
+                                        'testdata/context.pbtxt')
+    self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
     with open(initial_task_context, 'r') as fin:
       with open(self._task_context, 'w') as fout:
-        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
-                   .replace('OUTPATH', FLAGS.test_tmpdir))
+        fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
+                   .replace('OUTPATH', test_flags.temp_dir()))

     # Creates necessary term maps.
     with self.test_session() as sess:

@@ -175,7 +167,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
   def testWordEmbeddingInitializer(self):
     # Provide embeddings for the first three words in the word map.
-    records_path = os.path.join(FLAGS.test_tmpdir, 'records1')
+    records_path = os.path.join(test_flags.temp_dir(), 'records1')
     writer = tf.python_io.TFRecordWriter(records_path)
     writer.write(self._token_embedding('.', [1, 2]))
     writer.write(self._token_embedding(',', [3, 4]))

@@ -193,7 +185,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
                         embeddings[:3,])

   def testWordEmbeddingInitializerRepeatability(self):
-    records_path = os.path.join(FLAGS.test_tmpdir, 'records2')
+    records_path = os.path.join(test_flags.temp_dir(), 'records2')
     writer = tf.python_io.TFRecordWriter(records_path)
     writer.write(self._token_embedding('.', [1, 2, 3]))  # 3 dims
     del writer

@@ -234,7 +226,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
           vocabulary='/dev/null').eval()

   def testWordEmbeddingInitializerVocabularyFile(self):
-    records_path = os.path.join(FLAGS.test_tmpdir, 'records3')
+    records_path = os.path.join(test_flags.temp_dir(), 'records3')
     writer = tf.python_io.TFRecordWriter(records_path)
     writer.write(self._token_embedding('a', [1, 2, 3]))
     writer.write(self._token_embedding('b', [2, 3, 4]))

@@ -243,7 +235,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
     writer.write(self._token_embedding('e', [5, 6, 7]))
     del writer

-    vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary3')
+    vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
     with open(vocabulary_path, 'w') as vocabulary_file:
       vocabulary_file.write('a\nc\ne\nx\n')  # 'x' not in pretrained embeddings

@@ -271,8 +263,50 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
                          [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
                         embeddings[:3].eval())

+  def testWordEmbeddingInitializerPresetRowNumber(self):
+    records_path = os.path.join(test_flags.temp_dir(), 'records3')
+    writer = tf.python_io.TFRecordWriter(records_path)
+    writer.write(self._token_embedding('a', [1, 2, 3]))
+    writer.write(self._token_embedding('b', [2, 3, 4]))
+    writer.write(self._token_embedding('c', [3, 4, 5]))
+    writer.write(self._token_embedding('d', [4, 5, 6]))
+    writer.write(self._token_embedding('e', [5, 6, 7]))
+    del writer
+
+    vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
+    with open(vocabulary_path, 'w') as vocabulary_file:
+      vocabulary_file.write('a\nc\ne\nx\n')  # 'x' not in pretrained embeddings
+
+    # Enumerate a variety of configurations.
+    for cache_vectors_locally in [False, True]:
+      for num_special_embeddings in [None, 1, 2, 5]:  # None = use default of 3
+        for override_num_embeddings in [-1, 8, 10]:
+          with self.test_session():
+            embeddings = gen_parser_ops.word_embedding_initializer(
+                vectors=records_path,
+                vocabulary=vocabulary_path,
+                override_num_embeddings=override_num_embeddings,
+                cache_vectors_locally=cache_vectors_locally,
+                num_special_embeddings=num_special_embeddings)
+
+            # Expect 4 embeddings from the vocabulary plus special embeddings.
+            expected_num_embeddings = 4 + (num_special_embeddings or 3)
+            if override_num_embeddings > 0:
+              expected_num_embeddings = override_num_embeddings
+            self.assertAllEqual([expected_num_embeddings, 3],
+                                tf.shape(embeddings).eval())
+
+            # The first 3 embeddings should be pretrained.
+            norm_a = (1.0 + 4.0 + 9.0) ** 0.5
+            norm_c = (9.0 + 16.0 + 25.0) ** 0.5
+            norm_e = (25.0 + 36.0 + 49.0) ** 0.5
+            self.assertAllClose(
+                [[1.0 / norm_a, 2.0 / norm_a, 3.0 / norm_a],
+                 [3.0 / norm_c, 4.0 / norm_c, 5.0 / norm_c],
+                 [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
+                embeddings[:3].eval())
+
   def testWordEmbeddingInitializerVocabularyFileWithDuplicates(self):
-    records_path = os.path.join(FLAGS.test_tmpdir, 'records4')
+    records_path = os.path.join(test_flags.temp_dir(), 'records4')
     writer = tf.python_io.TFRecordWriter(records_path)
     writer.write(self._token_embedding('a', [1, 2, 3]))
     writer.write(self._token_embedding('b', [2, 3, 4]))

@@ -281,7 +315,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
     writer.write(self._token_embedding('e', [5, 6, 7]))
     del writer

-    vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary4')
+    vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary4')
     with open(vocabulary_path, 'w') as vocabulary_file:
       vocabulary_file.write('a\nc\ne\nx\ny\nx')  # 'x' duplicated

@@ -292,4 +326,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
 if __name__ == '__main__':
-  googletest.main()
+  tf.test.main()
research/syntaxnet/syntaxnet/registry.cc (+37, −0):

@@ -15,6 +15,12 @@ limitations under the License.
 #include "syntaxnet/registry.h"

+#include <set>
+#include <string>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
 namespace syntaxnet {

 // Global list of all component registries.

@@ -25,4 +31,35 @@ void RegistryMetadata::Register(RegistryMetadata *registry) {
   global_registry_list = registry;
 }

+string ComponentMetadata::DebugString() const {
+  return tensorflow::strings::StrCat("Registered '", name_, "' as class ",
+                                     class_name_, " at ", file_, ":", line_);
+}
+
+tensorflow::Status RegistryMetadata::Validate() {
+  static const tensorflow::Status *const status =
+      new tensorflow::Status(ValidateImpl());
+  return *status;
+}
+
+tensorflow::Status RegistryMetadata::ValidateImpl() {
+  // Iterates over the registries for each type.
+  for (RegistryMetadata *registry = global_registry_list; registry != nullptr;
+       registry = static_cast<RegistryMetadata *>(registry->link())) {
+    std::set<string> names;
+
+    // Searches for duplicate names within each component registry.
+    for (ComponentMetadata *component = *(registry->components_);
+         component != nullptr; component = component->link()) {
+      if (!names.insert(component->name()).second) {
+        return tensorflow::errors::InvalidArgument(
+            "Multiple classes named '", component->name(),
+            "' have been registered as ", registry->name(), ": ",
+            component->DebugString());
+      }
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
 }  // namespace syntaxnet
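Note the memoization in Validate(): the function-local static pointer means ValidateImpl() walks the registries at most once per process, and every later call returns the cached Status. That is what the header comment about repeated calls reusing the original result refers to. The same idiom in isolation (illustrative standalone code, not from the commit):

    // Call-once caching via a function-local static, as used by
    // RegistryMetadata::Validate() above.
    #include <cstdio>

    int ExpensiveCheck() {
      std::puts("checked once");  // side effect to show single evaluation
      return 42;
    }

    int CachedCheck() {
      // Initialized on the first call only; later calls reuse the result.
      static const int *const result = new int(ExpensiveCheck());
      return *result;
    }

    int main() {
      CachedCheck();  // prints "checked once"
      CachedCheck();  // cached: prints nothing
      return 0;
    }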
research/syntaxnet/syntaxnet/registry.h (+46, −7):

@@ -54,10 +54,13 @@ limitations under the License.
 #define SYNTAXNET_REGISTRY_H_

 #include <string.h>
+#include <memory>
 #include <string>
 #include <vector>

-#include "syntaxnet/utils.h"
+#include "syntaxnet/base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"

 namespace syntaxnet {

@@ -75,6 +78,9 @@ class ComponentMetadata {
   // Returns component name.
   const char *name() const { return name_; }

+  // Returns a human-readable description of this.
+  string DebugString() const;
+
   // Metadata objects can be linked in a list.
   ComponentMetadata *link() const { return link_; }
   void set_link(ComponentMetadata *link) { link_ = link; }

@@ -107,7 +113,16 @@ class RegistryMetadata : public ComponentMetadata {
   // Registers a component registry in the master registry.
   static void Register(RegistryMetadata *registry);

+  // Validates the registry; returns non-OK if there are duplicate component
+  // names of the same type. Situations where this can happen include
+  // accidental class name collisions, and linking in two different multiarch
+  // versions of the same component. Repeated calls use the original result.
+  static tensorflow::Status Validate();
+
  private:
+  // Implementation for validating the registry.
+  static tensorflow::Status ValidateImpl();
+
   // Location of list of components in registry.
   ComponentMetadata **components_;
 };

@@ -157,14 +172,21 @@ struct ComponentRegistry {
     T *object_;
   };

-  // Finds registrar for named component in registry.
-  const Registrar *GetComponent(const char *type) const {
+  // Finds registrar for named component in registry, returning null if not
+  // found.
+  const Registrar *GetComponentOrNull(const char *type) const {
     Registrar *r = components;
     while (r != nullptr && strcmp(type, r->type()) != 0) r = r->next();
-    if (r == nullptr) {
+    return r;
+  }
+
+  // Finds registrar for named component in registry, raising errors on
+  // failure.
+  const Registrar *GetComponent(const char *type) const {
+    const Registrar *result = GetComponentOrNull(type);
+    if (result == nullptr) {
       LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
     }
-    return r;
+    return result;
   }

   // Finds a named component in the registry.

@@ -196,7 +218,24 @@ class RegisterableClass {
   typedef ComponentRegistry<Factory> Registry;

   // Creates a new component instance.
-  static T *Create(const string &type) { return registry()->Lookup(type)(); }
+  static T *Create(const string &type) {
+    TF_CHECK_OK(syntaxnet::RegistryMetadata::Validate());
+    return registry()->Lookup(type)();
+  }
+
+  static tensorflow::Status CreateOrError(const string &type,
+                                          std::unique_ptr<T> *result) {
+    TF_RETURN_IF_ERROR(syntaxnet::RegistryMetadata::Validate());
+    const typename Registry::Registrar *component =
+        registry()->GetComponentOrNull(type.c_str());
+    if (component == nullptr) {
+      return tensorflow::errors::NotFound("Unknown ", registry()->name, ": ",
+                                          type);
+    } else {
+      result->reset(component->object()());
+      return tensorflow::Status::OK();
+    }
+  }

   // Returns registry for class.
   static Registry *registry() { return &registry_; }
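The practical difference between the two entry points: Create() still LOG(FATAL)s on an unknown name or a misconfigured registry, while the new CreateOrError() surfaces a Status the caller can handle. A usage sketch (the Tokenizer class is hypothetical; a real subclass would also need the DECLARE/REGISTER macros exercised in the new registry_test.cc below):

    // Hypothetical caller of the new CreateOrError() API; illustration only.
    #include <memory>

    #include "syntaxnet/registry.h"
    #include "tensorflow/core/lib/core/status.h"

    // Assumed registerable base class, analogous to ThingDoer in
    // registry_test.cc; registration macros omitted here.
    class Tokenizer : public syntaxnet::RegisterableClass<Tokenizer> {};

    tensorflow::Status MakeTokenizer(const string &type,
                                     std::unique_ptr<Tokenizer> *out) {
      // Validates the registry, then looks up |type|; an unknown name comes
      // back as NotFound instead of crashing the process.
      return Tokenizer::CreateOrError(type, out);
    }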
research/syntaxnet/syntaxnet/registry_test.cc (new file, mode 100644, +95):

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "syntaxnet/registry.h"

#include <memory>

#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace syntaxnet {

class ThingDoer : public RegisterableClass<ThingDoer> {};
DECLARE_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);
REGISTER_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);

class Foo : public ThingDoer {};
class Bar : public ThingDoer {};
class Bar2 : public ThingDoer {};
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "foo", Foo);
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar);

#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar2);  // bad

constexpr char kDuplicateError[] =
    "Multiple classes named 'bar' have been registered as Thing doer";
#endif

namespace {

#if !DRAGNN_REGISTRY_TEST_WITH_DUPLICATE

// Tests that CreateOrError() is successful for a properly registered
// component.
TEST(RegistryTest, CreateOrErrorSuccess) {
  std::unique_ptr<ThingDoer> object;
  TF_ASSERT_OK(ThingDoer::CreateOrError("foo", &object));
  ASSERT_NE(object, nullptr);
}

#else

// Tests that CreateOrError() fails if the registry is misconfigured.
TEST(RegistryTest, CreateOrErrorFailure) {
  std::unique_ptr<ThingDoer> object;
  EXPECT_THAT(ThingDoer::CreateOrError("bar", &object),
              test::IsErrorWithSubstr(kDuplicateError));
  ASSERT_EQ(object, nullptr);

  // Any call to Create has the same error.
  EXPECT_THAT(ThingDoer::CreateOrError("foo", &object),
              test::IsErrorWithSubstr(kDuplicateError));
}

// Tests that Create() dies if the registry is misconfigured.
TEST(RegistryTest, CreateFailure) {
  EXPECT_DEATH(ThingDoer::Create("bar"), kDuplicateError);
}

#endif

// Tests that CreateOrError() returns error if the component is unknown.
TEST(RegistryTest, CreateOrErrorUnknown) {
  std::unique_ptr<ThingDoer> object;
  EXPECT_FALSE(ThingDoer::CreateOrError("unknown", &object).ok());
}

// Tests that Validate() returns OK only when the registry is fine.
TEST(RegistryTest, Validate) {
#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
  EXPECT_THAT(RegistryMetadata::Validate(),
              test::IsErrorWithSubstr(kDuplicateError));
#else
  TF_EXPECT_OK(RegistryMetadata::Validate());
#endif
}

}  // namespace
}  // namespace syntaxnet
research/syntaxnet/syntaxnet/shared_store.h (+1, −1):

@@ -39,7 +39,7 @@ class SharedStore {
   static const T *Get(const string &name, Args &&...args);  // NOLINT(build/c++11)

-  // Like Get(), but creates the object with "closure->Run()". If the closure
+  // Like Get(), but creates the object with "(*closure)()". If the closure
   // returns null, we store a null in the SharedStore, but note that Release()
   // cannot be used to remove it. This is because Release() finds the object
   // by associative lookup, and there may be more than one null value, so we
research/syntaxnet/syntaxnet/structured_graph_builder.py (+2, −3):

@@ -115,9 +115,8 @@ class StructuredGraphBuilder(graph_builder.GreedyParser):
       return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))

     step = tf.constant(0, tf.int32, [])
-    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32, size=0,
-                                                dynamic_size=True)
+    scores_array = tensor_array_ops.TensorArray(
+        dtype=tf.float32, size=0, infer_shape=False, dynamic_size=True)
     alive = tf.constant(True, tf.bool, [batch_size])
     alive_steps = tf.constant(0, tf.int32, [batch_size])
     t = tf.while_loop(
research/syntaxnet/syntaxnet/syntaxnet.bzl (+21, −96):

@@ -12,99 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================

-load("@protobuf_archive//:protobuf.bzl", "cc_proto_library")
-load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
-
-def if_cuda(if_true, if_false=[]):
-  """Shorthand for select()'ing on whether we're building with CUDA."""
-  return select({
-      "@local_config_cuda//cuda:using_nvcc": if_true,
-      "@local_config_cuda//cuda:using_clang": if_true,
-      "//conditions:default": if_false
-  })
-
-def tf_copts():
-  return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
-          if_cuda(["-DGOOGLE_CUDA=1"]) +
-          select({"@org_tensorflow//tensorflow:darwin": [],
-                  "//conditions:default": ["-pthread"]}))
-
-def tf_proto_library(name, srcs=[], has_services=False,
-                     deps=[], visibility=None, testonly=0,
-                     cc_api_version=2, go_api_version=2,
-                     java_api_version=2, py_api_version=2):
-  native.filegroup(name=name + "_proto_srcs",
-                   srcs=srcs,
-                   testonly=testonly,)
-  cc_proto_library(name=name,
-                   srcs=srcs,
-                   deps=deps,
-                   cc_libs=["@protobuf_archive//:protobuf"],
-                   protoc="@protobuf_archive//:protoc",
-                   default_runtime="@protobuf_archive//:protobuf",
-                   testonly=testonly,
-                   visibility=visibility,)
-
-def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
-  py_proto_library(name=name,
-                   srcs=srcs,
-                   srcs_version="PY2AND3",
-                   deps=deps,
-                   default_runtime="@protobuf_archive//:protobuf_python",
-                   protoc="@protobuf_archive//:protoc",
-                   visibility=visibility,
-                   testonly=testonly,)
-
-# Given a list of "op_lib_names" (a list of files in the ops directory
-# without their .cc extensions), generate a library for that file.
-def tf_gen_op_libs(op_lib_names):
-  # Make library out of each op so it can also be used to generate wrappers
-  # for various languages.
-  for n in op_lib_names:
-    native.cc_library(name=n + "_op_lib",
-                      copts=tf_copts(),
-                      srcs=["ops/" + n + ".cc"],
-                      deps=(["@org_tensorflow//tensorflow/core:framework"]),
-                      visibility=["//visibility:public"],
-                      alwayslink=1,
-                      linkstatic=1,)
-
-# Invoke this rule in .../tensorflow/python to build the wrapper library.
-def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
-                         require_shape_functions=False):
-  # Construct a cc_binary containing the specified ops.
-  tool_name = "gen_" + name + "_py_wrappers_cc"
-  if not deps:
-    deps = ["//tensorflow/core:" + name + "_op_lib"]
-  native.cc_binary(
-      name=tool_name,
-      linkopts=["-lm"],
-      copts=tf_copts(),
-      linkstatic=1,  # Faster to link this one-time-use binary dynamically
-      deps=(["@org_tensorflow//tensorflow/core:framework",
-             "@org_tensorflow//tensorflow/python:python_op_gen_main"] + deps),
-  )
-
-  # Invoke the previous cc_binary to generate a python file.
-  if not out:
-    out = "ops/gen_" + name + ".py"
-  native.genrule(
-      name=name + "_pygenrule",
-      outs=[out],
-      tools=[tool_name],
-      cmd=("$(location " + tool_name + ") " + ",".join(hidden) + " " +
-           ("1" if require_shape_functions else "0") + " > $@"))
-
-  # Make a py_library out of the generated python file.
-  native.py_library(name=name,
-                    srcs=[out],
-                    srcs_version="PY2AND3",
-                    visibility=visibility,
-                    deps=[
-                        "@org_tensorflow//tensorflow/python:framework_for_generated_wrappers",
-                    ],)
+"""Build rules for Syntaxnet."""
+
+load(
+    "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
+    orig_tf_proto_library_cc = "tf_proto_library_cc",
+)
+load(
+    "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
+    orig_tf_proto_library_py = "tf_proto_library_py",
+)
+
+# For some reason, tf_proto_library_cc() isn't obeying the default_visibility
+# directive at the top of the build file. So just set it to public (which it
+# is anyway).
+def tf_proto_library_cc(name, visibility = [], **kwargs):
+  visibility = visibility if visibility else ["//visibility:public"]
+  return orig_tf_proto_library_cc(name, visibility = visibility, **kwargs)
+
+def tf_proto_library_py(name, visibility = [], **kwargs):
+  visibility = visibility if visibility else ["//visibility:public"]
+  return orig_tf_proto_library_py(name, visibility = visibility, **kwargs)
research/syntaxnet/syntaxnet/term_frequency_map.cc (+54, −15):

@@ -19,6 +19,7 @@ limitations under the License.
 #include <algorithm>
 #include <limits>

+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/io/buffered_inputstream.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"

@@ -52,8 +53,9 @@ void TermFrequencyMap::Clear() {
   term_data_.clear();
 }

-void TermFrequencyMap::Load(const string &filename, int min_frequency,
-                            int max_num_terms) {
+tensorflow::Status TermFrequencyMap::TryLoad(const string &filename,
+                                             int min_frequency,
+                                             int max_num_terms) {
   Clear();

   // If max_num_terms is non-positive, replace it with INT_MAX.

@@ -61,46 +63,83 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
   // Read the first line (total # of terms in the mapping).
   std::unique_ptr<tensorflow::RandomAccessFile> file;
-  TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
+  TF_RETURN_IF_ERROR(
+      tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
   static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
   tensorflow::io::RandomAccessInputStream stream(file.get());
   tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
   string line;
-  TF_CHECK_OK(buffer.ReadLine(&line));
+  TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
   int32 total = -1;
-  CHECK(utils::ParseInt32(line.c_str(), &total)) << "Unable to parse from "
-                                                 << filename;
-  CHECK_GE(total, 0);
+  if (!utils::ParseInt32(line.c_str(), &total)) {
+    return tensorflow::errors::InvalidArgument(
+        filename, ":0: Unable to parse term map size");
+  }
+  if (total < 0) {
+    return tensorflow::errors::InvalidArgument(
+        filename, ":0: Invalid term map size: ", total);
+  }

   // Read the mapping.
   int64 last_frequency = -1;
   for (int i = 0; i < total && i < max_num_terms; ++i) {
-    TF_CHECK_OK(buffer.ReadLine(&line));
+    TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
+    static LazyRE2 re = {"(.*) (\\d*)"};
     string term;
     int64 frequency = 0;
-    CHECK(RE2::FullMatch(line, "(.*) (\\d*)", &term, &frequency));
-    CHECK(!term.empty());
-    CHECK_GT(frequency, 0);
+    if (!RE2::FullMatch(line, *re, &term, &frequency)) {
+      return tensorflow::errors::InvalidArgument(
+          filename, ":", i + 1,
+          ": Couldn't split term and frequency in line: ", line);
+    }
+    if (term.empty()) {
+      return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
+                                                 ": Invalid empty term");
+    }
+    if (frequency <= 0) {
+      return tensorflow::errors::InvalidArgument(
+          filename, ":", i + 1, ": Invalid frequency: term=", term,
+          " frequency=", frequency);
+    }

     // Check frequency sorting (descending order).
-    if (i > 0) CHECK_GE(last_frequency, frequency);
+    if (i > 0 && last_frequency < frequency) {
+      return tensorflow::errors::InvalidArgument(
+          filename, ":", i + 1, ": Non-descending frequencies: current=",
+          frequency, " previous=", last_frequency);
+    }
     last_frequency = frequency;

     // Ignore low-frequency items.
     if (frequency < min_frequency) continue;

     // Check uniqueness of the mapped terms.
-    CHECK(term_index_.find(term) == term_index_.end())
-        << "File " << filename << " has duplicate term: " << term;
+    if (term_index_.find(term) != term_index_.end()) {
+      return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
+                                                 ": Duplicate term: ", term);
+    }

     // Assign the next available index.
     const int index = term_index_.size();
     term_index_[term] = index;
     term_data_.push_back(std::pair<string, int64>(term, frequency));
   }
-  CHECK_EQ(term_index_.size(), term_data_.size());
+  if (term_index_.size() != term_data_.size()) {
+    return tensorflow::errors::Internal(
+        "Unexpected size mismatch between term index (", term_index_.size(),
+        ") and term data (", term_data_.size(), ")");
+  }
   LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename
             << ".";
+  return tensorflow::Status::OK();
+}
+
+void TermFrequencyMap::Load(const string &filename, int min_frequency,
+                            int max_num_terms) {
+  TF_CHECK_OK(TryLoad(filename, min_frequency, max_num_terms));
+}

 struct TermFrequencyMap::SortByFrequencyThenTerm {
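With this split, Load() preserves the old crash-on-bad-input contract by wrapping TryLoad() in TF_CHECK_OK(), while callers that would rather recover, for example when probing an optional resource, can call TryLoad() directly. A sketch of such a caller (illustrative; the term_frequency_map.h include path is assumed, not shown in this diff):

    // Illustrative caller of the new TryLoad() API; not commit code.
    #include "syntaxnet/term_frequency_map.h"  // assumed header location
    #include "tensorflow/core/lib/core/status.h"

    tensorflow::Status LoadWordMapOrError(const string &filename,
                                          syntaxnet::TermFrequencyMap *map) {
      // Non-positive max_num_terms is replaced with INT_MAX inside TryLoad,
      // so 0 means "no limit"; a malformed file now yields InvalidArgument
      // instead of a CHECK failure.
      return map->TryLoad(filename, /*min_frequency=*/0, /*max_num_terms=*/0);
    }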