Commit 277f99c7 authored by Ivan Bogatyy, committed by GitHub

Merge pull request #1243 from bogatyy/master

Add license headers, fix some macOS issues
parents f7cea8d0 ea3fa4a3
. PUNCT
CC CONJ
DT DET
NN NOUN
NNS NOUN
PRP PRON
VBP VERB
10
. 5
books 4
They 3
I 2
buy 2
have 2
no 2
sell 2
and 1
clue 1
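The two lists above are the test lexicon files: a tag map pairing each fine-grained Penn Treebank tag with its coarse category, and a term-frequency map whose first line appears to hold the number of entries, followed by one "term count" pair per line in descending frequency order. A minimal loader sketch under that assumed layout (the helper name is hypothetical, not part of the commit):

def load_term_frequency_map(path):
  """Reads an assumed 'count, then term/frequency pairs' layout."""
  with open(path) as f:
    num_entries = int(f.readline())
    terms = {}
    for _ in range(num_entries):
      term, count = f.readline().rsplit(None, 1)
      terms[term] = int(count)
  return terms

# e.g. load_term_frequency_map('word-map') -> {'.': 5, 'books': 4, ...}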
target {
name: "joint"
component_weights: [0, # lengths
0, # bilstm
1, # tagger
0, # heads
0, # modifiers
0, # digraph
1, # parser
0, # parsed_head_tokens
0, # parsed_heads
0, # parsed_modifiers
0, # labels
1] # labeler
unroll_using_oracle: [true, true, true, true, true, true,
true, true, true, true, true, true]
}
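In this target spec, component_weights appears to scale each component's contribution to the joint objective: only the tagger, parser, and labeler carry weight 1, while the remaining components are unrolled with the oracle but contribute no loss. A sketch of that weighted combination (illustrative only; the names, loss values, and combination code are assumptions, not the trainer's actual implementation):

# Hypothetical per-component losses; in training these would come from
# the corresponding network components.
component_losses = {
    'lengths': 0.25, 'bilstm': 0.75, 'tagger': 0.5,
    'parser': 1.5, 'labeler': 1.0,
}
component_weights = {'tagger': 1.0, 'parser': 1.0, 'labeler': 1.0}

def joint_loss(losses, weights):
  """Weighted sum; components with weight 0 drop out of the objective."""
  return sum(weights.get(name, 0.0) * loss for name, loss in losses.items())

print(joint_loss(component_losses, component_weights))  # 0.5 + 1.5 + 1.0 = 3.0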
1 They they PRON PRP Case=Nom|Number=Plur 2 nsubj _ _
2 buy buy VERB VBP Number=Plur|Person=3|Tense=Pres 0 ROOT _ _
3 books book NOUN NNS Number=Plur 2 obj _ SpaceAfter=No
4 . . PUNCT . _ 2 punct _ _
1 They they PRON PRP Case=Nom|Number=Plur 2 nsubj _ _
2 sell sell VERB VBP Number=Plur|Person=3|Tense=Pres 0 ROOT _ _
3 books book NOUN NNS Number=Plur 2 obj _ SpaceAfter=No
4 . . PUNCT . _ 2 punct _ _
1 They they PRON PRP Case=Nom|Number=Plur 2 nsubj _ _
2 buy buy VERB VBP Number=Plur|Person=3|Tense=Pres 0 ROOT _ _
3 and and CONJ CC _ 4 cc _ _
4 sell sell VERB VBP Number=Plur|Person=3|Tense=Pres 2 conj _ _
5 books book NOUN NNS Number=Plur 2 obj _ SpaceAfter=No
6 . . PUNCT . _ 2 punct _ _
1 I I PRON PRP Case=Nom|Number=Sing|Person=1 2 nsubj _ _
2 have have VERB VBP Number=Sing|Person=1|Tense=Pres 0 ROOT _ _
3 no no DET DT PronType=Neg 4 det _ _
4 clue clue NOUN NN Number=Sing 2 obj _ SpaceAfter=No
5 . . PUNCT . _ 2 punct _ _
1 I I PRON PRP Case=Nom|Number=Sing|Person=1 2 nsubj _ _
2 have have VERB VBP Number=Sing|Person=1|Tense=Pres 0 ROOT _ _
3 no no DET DT PronType=Neg 4 det _ _
4 books book NOUN NNS Number=Plur 2 obj _ SpaceAfter=No
5 . . PUNCT . _ 2 punct _ _
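The sentences above are the CoNLL-U test corpus: one token per line with ten tab-separated columns (ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC) and a blank line between sentences. A minimal reader sketch, not part of the commit:

import collections

Token = collections.namedtuple(
    'Token', ['index', 'form', 'lemma', 'upos', 'xpos', 'feats',
              'head', 'deprel', 'deps', 'misc'])

def read_conllu(lines):
  """Yields one list of Tokens per sentence."""
  sentence = []
  for line in lines:
    line = line.rstrip('\n')
    if not line:          # blank line ends the current sentence
      if sentence:
        yield sentence
      sentence = []
    elif not line.startswith('#'):  # skip comment lines
      sentence.append(Token(*line.split('\t')))
  if sentence:
    yield sentence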
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file builds all of the JavaScript into a minified "hermetic" bundle.js
# file, which is written out into the same directory as this script.
#
......
#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file starts up a development server, using webpack in development mode.
# It takes no arguments. See README.md for more information.
......
@@ -178,9 +178,9 @@ cc_library(
     srcs = ["char_ngram_string_extractor.cc"],
     hdrs = ["char_ngram_string_extractor.h"],
     deps = [
+        ":base",
         ":segmenter_utils",
         ":task_context",
-        "@org_tensorflow//tensorflow/core:lib",
     ],
 )
@@ -365,7 +365,6 @@ cc_library(
         ":utils",
         ":whole_sentence_features",
         ":workspace",
-        "@org_tensorflow//tensorflow/core:lib",
     ],
     alwayslink = 1,
 )
@@ -390,6 +389,7 @@ cc_library(
     srcs = ["embedding_feature_extractor.cc"],
     hdrs = ["embedding_feature_extractor.h"],
     deps = [
+        ":base",
         ":feature_extractor",
         ":parser_transitions",
         ":sentence_features",
@@ -397,7 +397,6 @@ cc_library(
         ":task_context",
         ":utils",
         ":workspace",
-        "@org_tensorflow//tensorflow/core:lib",
     ],
 )
@@ -455,6 +454,7 @@ cc_library(
     srcs = ["lexicon_builder.cc"],
     deps = [
         ":affix",
+        ":base",
         ":char_ngram_string_extractor",
         ":feature_extractor",
         ":parser_transitions",
@@ -464,8 +464,6 @@ cc_library(
         ":term_frequency_map",
         ":text_formats",
         ":utils",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:lib",
     ],
     alwayslink = 1,
 )
@@ -484,11 +482,11 @@ cc_library(
     name = "parser_ops_cc",
     srcs = ["ops/parser_ops.cc"],
     deps = [
+        ":base",
         ":document_filters",
         ":lexicon_builder",
         ":reader_ops",
         ":unpack_sparse_features",
-        "@org_tensorflow//tensorflow/core:framework",
     ],
     alwayslink = 1,
 )
......
@@ -478,11 +478,15 @@ class GreedyParser(object):
     """Embeddings at the given index will be set to pretrained values."""

     def _Initializer(shape, dtype=tf.float32, partition_info=None):
+      """Variable initializer that loads pretrained embeddings."""
       unused_dtype = dtype
+      seed1, seed2 = tf.get_seed(self._seed)
       t = gen_parser_ops.word_embedding_initializer(
           vectors=embeddings_path,
           task_context=task_context,
-          embedding_init=self._embedding_init)
+          embedding_init=self._embedding_init,
+          seed=seed1,
+          seed2=seed2)
       t.set_shape(shape)
       return t
......
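The fix above threads the model's seed through tf.get_seed, which combines the graph-level seed with an op-level seed into the (seed, seed2) pair the op expects, making the embedding initialization repeatable. A small usage sketch of the same TF 1.x pattern (the seed values are arbitrary):

import tensorflow as tf

tf.set_random_seed(42)           # graph-level seed
seed1, seed2 = tf.get_seed(7)    # derive an op-level (seed, seed2) pair
# Passing both attrs to an op makes its randomness deterministic as long
# as at least one of the two seeds is non-zero.
print(seed1, seed2)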
@@ -249,6 +249,8 @@ REGISTER_OP("WordEmbeddingInitializer")
     .Attr("vectors: string")
     .Attr("task_context: string")
     .Attr("embedding_init: float = 1.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
     .Doc(R"doc(
 Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
 every word specified in a text vocabulary file.
@@ -256,6 +258,13 @@ every word specified in a text vocabulary file.
 word_embeddings: a tensor containing word embeddings from the specified sstable.
 vectors: path to recordio of word embedding vectors.
 task_context: file path at which to read the task context.
+embedding_init: embedding vectors that are not found in the input sstable are
+  initialized randomly from a normal distribution with zero mean and
+  std dev = embedding_init / sqrt(embedding_size).
+seed: if either `seed` or `seed2` is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a random
+  seed.
+seed2: a second seed to avoid seed collision.
 )doc");

 REGISTER_OP("DocumentSource")
......
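Per the new doc text, out-of-vocabulary rows are drawn from a zero-mean normal with std dev = embedding_init / sqrt(embedding_size); with the default embedding_init = 1.0 and, say, 64 dimensions, that is 1/8 = 0.125. A NumPy sketch of the same scheme (illustrative; not the kernel's actual code path):

import numpy as np

def random_embeddings(num_words, embedding_size, embedding_init=1.0, seed=1):
  """Zero-mean normal rows with std dev = embedding_init / sqrt(size)."""
  rng = np.random.RandomState(seed)
  stddev = embedding_init / np.sqrt(embedding_size)
  return rng.normal(0.0, stddev, size=(num_words, embedding_size))

# embedding_init=1.0, embedding_size=64 -> std dev = 1/8 = 0.125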
@@ -450,6 +450,13 @@ class WordEmbeddingInitializer : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->GetAttr("embedding_init", &embedding_init_));

+    // Convert the seeds into a single 64-bit seed.  NB: seed=0,seed2=0
+    // converts into seed_=0, which causes Eigen PRNGs to seed
+    // non-deterministically.
+    int seed, seed2;
+    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed));
+    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2));
+    seed_ = static_cast<uint64>(seed) | static_cast<uint64>(seed2) << 32;
+
     // Sets up number and type of inputs and outputs.
     OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_FLOAT}));
   }
@@ -479,11 +486,10 @@ class WordEmbeddingInitializer : public OpKernel {
           context, context->allocate_output(
                        0, TensorShape({word_map->Size() + 3, embedding_size}),
                        &embedding_matrix));
-      embedding_matrix->matrix<float>()
-          .setRandom<Eigen::internal::NormalRandomGenerator<float>>();
-      embedding_matrix->matrix<float>() =
-          embedding_matrix->matrix<float>() * static_cast<float>(
-              embedding_init_ / sqrt(embedding_size));
+      auto matrix = embedding_matrix->matrix<float>();
+      Eigen::internal::NormalRandomGenerator<float> prng(seed_);
+      matrix =
+          matrix.random(prng) * (embedding_init_ / sqrtf(embedding_size));
     }
     if (vocab.find(embedding.token()) != vocab.end()) {
       SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
@@ -544,6 +550,9 @@ class WordEmbeddingInitializer : public OpKernel {
   // Task context used to configure this op.
   TaskContext task_context_;

+  // Seed for random initialization.
+  uint64 seed_ = 0;
+
   // Embedding vectors that are not found in the input sstable are initialized
   // randomly from a normal distribution with zero mean and
   // std dev = embedding_init_ / sqrt(embedding_size).
......
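The constructor packs the two 32-bit seed attrs into one 64-bit value, with seed in the low word and seed2 in the high word; only seed = seed2 = 0 yields seed_ = 0, the case where Eigen falls back to non-deterministic seeding. The same packing in Python, for illustration:

def combine_seeds(seed, seed2):
  """Mirrors the kernel's: seed_ = uint64(seed) | uint64(seed2) << 32."""
  return (seed & 0xFFFFFFFF) | ((seed2 & 0xFFFFFFFF) << 32)

assert combine_seeds(0, 0) == 0           # the only non-deterministic case
assert combine_seeds(1, 0) == 1
assert combine_seeds(0, 1) == 1 << 32
assert combine_seeds(123, 456) == (456 << 32) | 123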
@@ -167,19 +167,19 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
     logging.info('Result: %s', res)
     self.assertEqual(res[0], 2)

-  def testWordEmbeddingInitializer(self):
-    def _TokenEmbedding(token, embedding):
-      e = dictionary_pb2.TokenEmbedding()
-      e.token = token
-      e.vector.values.extend(embedding)
-      return e.SerializeToString()
+  def _token_embedding(self, token, embedding):
+    e = dictionary_pb2.TokenEmbedding()
+    e.token = token
+    e.vector.values.extend(embedding)
+    return e.SerializeToString()

+  def testWordEmbeddingInitializer(self):
     # Provide embeddings for the first three words in the word map.
-    records_path = os.path.join(FLAGS.test_tmpdir, 'sstable-00000-of-00001')
+    records_path = os.path.join(FLAGS.test_tmpdir, 'records1')
     writer = tf.python_io.TFRecordWriter(records_path)
-    writer.write(_TokenEmbedding('.', [1, 2]))
-    writer.write(_TokenEmbedding(',', [3, 4]))
-    writer.write(_TokenEmbedding('the', [5, 6]))
+    writer.write(self._token_embedding('.', [1, 2]))
+    writer.write(self._token_embedding(',', [3, 4]))
+    writer.write(self._token_embedding('the', [5, 6]))
     del writer

     with self.test_session():
@@ -192,6 +192,34 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
           [5. / (25 + 36) ** .5, 6. / (25 + 36) ** .5]]),
           embeddings[:3,])

+  def testWordEmbeddingInitializerRepeatability(self):
+    records_path = os.path.join(FLAGS.test_tmpdir, 'records2')
+    writer = tf.python_io.TFRecordWriter(records_path)
+    writer.write(self._token_embedding('.', [1, 2, 3]))  # 3 dims
+    del writer
+
+    # As long as there is one non-zero seed, the result should be repeatable.
+    for seed1, seed2 in [(0, 1), (1, 0), (123, 456)]:
+      with tf.Graph().as_default(), self.test_session():
+        embeddings1 = gen_parser_ops.word_embedding_initializer(
+            vectors=records_path,
+            task_context=self._task_context,
+            seed=seed1,
+            seed2=seed2)
+        embeddings2 = gen_parser_ops.word_embedding_initializer(
+            vectors=records_path,
+            task_context=self._task_context,
+            seed=seed1,
+            seed2=seed2)
+
+        # The number of terms is based on the word map, which may change if
+        # the test corpus is updated.  Just assert that there are some terms.
+        self.assertGreater(tf.shape(embeddings1)[0].eval(), 0)
+        self.assertGreater(tf.shape(embeddings2)[0].eval(), 0)
+        self.assertEqual(tf.shape(embeddings1)[1].eval(), 3)
+        self.assertEqual(tf.shape(embeddings2)[1].eval(), 3)
+        self.assertAllEqual(embeddings1.eval(), embeddings2.eval())

 if __name__ == '__main__':
   googletest.main()
-Subproject commit f7ed0682f67a9a767ee30ad62233847e8a8cbb95
+Subproject commit a7d6015d3759bee447c8103979a5ebc831ce23d1