Commit a57a00f6 (ModelZoo / ResNet50_tensorflow)
Authored Jul 17, 2017 by Mark Omernick; committed by GitHub on Jul 17, 2017.

Merge pull request #1959 from tensorflow/add_fixed_embeddings

Adds an op to handle pre-computed word embeddings.

Parents: 3646eef8, ccf606c3
Showing 7 changed files with 247 additions and 45 deletions.
Files changed:

  syntaxnet/dragnn/python/BUILD             +1   -0
  syntaxnet/dragnn/python/network_units.py  +12  -9
  syntaxnet/syntaxnet/BUILD                 +9   -0
  syntaxnet/syntaxnet/ops/parser_ops.cc     +15  -4
  syntaxnet/syntaxnet/reader_ops.cc         +118 -31
  syntaxnet/syntaxnet/reader_ops_test.py    +71  -1
  syntaxnet/syntaxnet/syntaxnet_ops.py      +21  -0  (new file)
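For orientation before the per-file diffs: the new WordEmbeddingInitializer op is exposed to Python as word_embedding_initializer. A minimal usage sketch, assuming the SyntaxNet ops are built; the file paths below are placeholders, not values from this commit:

import tensorflow as tf

from syntaxnet.ops import gen_parser_ops
import syntaxnet.load_parser_ops  # registers the C++ kernels

# `vectors` is a TF record file of dist_belief.TokenEmbedding protos and
# `vocabulary` a text file with one unique word per line (both hypothetical).
embeddings = gen_parser_ops.word_embedding_initializer(
    vectors='/path/to/vectors.tfrecord',
    vocabulary='/path/to/vocabulary.txt',
    num_special_embeddings=3)

with tf.Session() as sess:
  # Shape is [num_words + num_special_embeddings, embedding_dim].
  print(sess.run(tf.shape(embeddings)))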
syntaxnet/dragnn/python/BUILD

@@ -89,6 +89,7 @@ py_library(
     srcs = ["network_units.py"],
     deps = [
         ":dragnn_ops",
+        "//syntaxnet:syntaxnet_ops",
         "//syntaxnet/util:check",
         "//syntaxnet/util:pyregistry",
         "@org_tensorflow//tensorflow:tensorflow_py",
syntaxnet/dragnn/python/network_units.py

@@ -15,9 +15,11 @@
 """Basic network units used in assembling DRAGNN graphs."""

-from abc import ABCMeta
-from abc import abstractmethod
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
 import tensorflow as tf
 from tensorflow.python.ops import nn
@@ -25,6 +27,7 @@ from tensorflow.python.ops import tensor_array_ops as ta
 from tensorflow.python.platform import tf_logging as logging

 from dragnn.python import dragnn_ops
+from syntaxnet import syntaxnet_ops
 from syntaxnet.util import check
 from syntaxnet.util import registry
@@ -135,11 +138,11 @@ def add_embeddings(channel_id, feature_spec, seed=None):
     raise RuntimeError('vocab resource contains more than one part:\n%s',
                        str(feature_spec.vocab))
   seed1, seed2 = tf.get_seed(seed)
-  embeddings = dragnn_ops.dragnn_embedding_initializer(
-      embedding_input=feature_spec.pretrained_embedding_matrix.part[0]
-      .file_pattern,
-      vocab=feature_spec.vocab.part[0].file_pattern,
-      scaling_coefficient=1.0,
+  embeddings = syntaxnet_ops.word_embedding_initializer(
+      vectors=feature_spec.pretrained_embedding_matrix.part[0].file_pattern,
+      vocabulary=feature_spec.vocab.part[0].file_pattern,
+      num_special_embeddings=1,
+      embedding_init=1.0,
       seed=seed1,
       seed2=seed2)
   return tf.get_variable(name, initializer=tf.reshape(embeddings, shape))
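The op returns a fully materialized matrix that is then wrapped in a variable via tf.get_variable. A minimal numpy sketch of the initialization semantics documented by the op (illustrative names, not code from this commit): rows for words found in the vectors file are L2-normalized copies of the pretrained vectors, and every other row keeps its random draw with stddev = embedding_init / sqrt(embedding_size).

import numpy as np

embedding_init = 1.0
embedding_size = 3

# A word present in the vectors file gets its pretrained vector, L2-normalized.
pretrained = np.array([1.0, 2.0, 3.0])
row_for_known_word = pretrained / np.linalg.norm(pretrained)

# A word absent from the vectors file keeps its random initialization:
# zero mean, stddev = embedding_init / sqrt(embedding_size).
row_for_unknown_word = np.random.randn(embedding_size) * (
    embedding_init / np.sqrt(embedding_size))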
@@ -626,7 +629,7 @@ class NetworkUnitInterface(object):
     layers (list): List of Layer objects to track network layers that should
       be written to Tensors during training and inference.
   """
-  __metaclass__ = ABCMeta  # required for @abstractmethod
+  __metaclass__ = abc.ABCMeta  # required for @abstractmethod

   def __init__(self, component, init_layers=None, init_context_layers=None):
     """Initializes parameters for embedding matrices.
@@ -738,7 +741,7 @@ class NetworkUnitInterface(object):
             [attention_hidden_layer_size, component.num_actions],
             initializer=tf.random_normal_initializer(stddev=1e-4)))

-  @abstractmethod
+  @abc.abstractmethod
   def create(self,
              fixed_embeddings,
              linked_embeddings,
syntaxnet/syntaxnet/BUILD

@@ -747,6 +747,15 @@ py_library(
     data = [":parser_ops.so"],
 )

+py_library(
+    name = "syntaxnet_ops",
+    srcs = ["syntaxnet_ops.py"],
+    deps = [
+        ":parser_ops",
+        ":load_parser_ops_py",
+    ],
+)
+
 py_library(
     name = "graph_builder",
     srcs = ["graph_builder.py"],
syntaxnet/syntaxnet/ops/parser_ops.cc

@@ -247,7 +247,10 @@ weights: vector of weight extracted from the SparseFeatures proto.
 REGISTER_OP("WordEmbeddingInitializer")
     .Output("word_embeddings: float")
     .Attr("vectors: string")
-    .Attr("task_context: string")
+    .Attr("task_context: string = ''")
+    .Attr("vocabulary: string = ''")
+    .Attr("cache_vectors_locally: bool = true")
+    .Attr("num_special_embeddings: int = 3")
     .Attr("embedding_init: float = 1.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
@@ -255,9 +258,17 @@ REGISTER_OP("WordEmbeddingInitializer")
 Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
 every word specified in a text vocabulary file.

-word_embeddings: a tensor containing word embeddings from the specified sstable.
-vectors: path to recordio of word embedding vectors.
-task_context: file path at which to read the task context.
+word_embeddings: a tensor containing word embeddings from the specified table.
+vectors: path to TF record file of word embedding vectors.
+task_context: file path at which to read the task context, for its "word-map"
+  input. Exactly one of `task_context` or `vocabulary` must be specified.
+vocabulary: path to vocabulary file, which contains one unique word per line, in
+  order. Exactly one of `task_context` or `vocabulary` must be specified.
+cache_vectors_locally: Whether to cache the vectors file to a local temp file
+  before parsing it. This greatly reduces initialization time when the vectors
+  are stored remotely, but requires that "/tmp" has sufficient space.
+num_special_embeddings: Number of special embeddings to allocate, in addition to
+  those allocated for real words.
 embedding_init: embedding vectors that are not found in the input sstable are
   initialized randomly from a normal distribution with zero mean and
   std dev = embedding_init / sqrt(embedding_size).
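Because task_context and vocabulary now both default to the empty string, the kernel can require that exactly one of them is set. A sketch of the resulting contract, mirroring the new tests in reader_ops_test.py further below (paths are placeholders; the failures surface when the kernel is constructed, which the tests trigger via .eval()):

from syntaxnet.ops import gen_parser_ops

# Accepted: exactly one vocabulary source.
gen_parser_ops.word_embedding_initializer(
    vectors='/path/to/vectors', vocabulary='/path/to/vocabulary')

# Rejected once the kernel runs: no vocabulary source at all...
gen_parser_ops.word_embedding_initializer(vectors='/path/to/vectors')

# ...and likewise both sources at once.
gen_parser_ops.word_embedding_initializer(
    vectors='/path/to/vectors',
    task_context='/path/to/task_context',
    vocabulary='/path/to/vocabulary')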
syntaxnet/syntaxnet/reader_ops.cc

@@ -34,9 +34,11 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/io/table.h"
 #include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/env.h"
@@ -439,14 +441,18 @@ class WordEmbeddingInitializer : public OpKernel {
  public:
   explicit WordEmbeddingInitializer(OpKernelConstruction *context)
       : OpKernel(context) {
-    string file_path, data;
-    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
-    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
-                                             file_path, &data));
-    OP_REQUIRES(context,
-                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
-                InvalidArgument("Could not parse task context at ", file_path));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("task_context", &task_context_path_));
+    OP_REQUIRES_OK(context, context->GetAttr("vocabulary", &vocabulary_path_));
+    OP_REQUIRES(
+        context, task_context_path_.empty() != vocabulary_path_.empty(),
+        InvalidArgument(
+            "Exactly one of task_context or vocabulary must be specified"));
     OP_REQUIRES_OK(context, context->GetAttr("vectors", &vectors_path_));
+    OP_REQUIRES_OK(context, context->GetAttr("cache_vectors_locally",
+                                             &cache_vectors_locally_));
+    OP_REQUIRES_OK(context, context->GetAttr("num_special_embeddings",
+                                             &num_special_embeddings_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("embedding_init", &embedding_init_));
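The OP_REQUIRES condition compares the two empty() flags with !=, which acts as an exclusive-or and is what enforces "exactly one of task_context or vocabulary". The same truth table in Python, as a sketch with hypothetical helper names:

def exactly_one_specified(task_context, vocabulary):
  # empty(a) != empty(b) is true only when exactly one path is non-empty.
  return (task_context == '') != (vocabulary == '')

assert exactly_one_specified('', '/path/to/vocabulary')
assert exactly_one_specified('/path/to/context', '')
assert not exactly_one_specified('', '')
assert not exactly_one_specified('/path/to/context', '/path/to/vocabulary')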
@@ -462,43 +468,117 @@ class WordEmbeddingInitializer : public OpKernel {
   }

   void Compute(OpKernelContext *context) override {
-    // Loads words from vocabulary with mapping to ids.
-    string path = TaskContext::InputFile(*task_context_.GetInput("word-map"));
-    const TermFrequencyMap *word_map =
-        SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
-    unordered_map<string, int64> vocab;
-    for (int i = 0; i < word_map->Size(); ++i) {
-      vocab[word_map->GetTerm(i)] = i;
-    }
+    std::unordered_map<string, int64> vocab;
+    OP_REQUIRES_OK(context, LoadVocabulary(&vocab));

     // Creates a reader pointing to a local copy of the vectors recordio.
-    string tmp_vectors_path;
-    OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &tmp_vectors_path));
-    ProtoRecordReader reader(tmp_vectors_path);
+    string vectors_path = vectors_path_;
+    if (cache_vectors_locally_) {
+      OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &vectors_path));
+    }
+    ProtoRecordReader reader(vectors_path);

-    // Loads the embedding vectors into a matrix.
+    // Load the embedding vectors into a matrix. Since the |embedding_matrix|
+    // output cannot be allocated until the embedding dimension is known, delay
+    // allocation until the first iteration of the loop.
     Tensor *embedding_matrix = nullptr;
     TokenEmbedding embedding;
     while (reader.Read(&embedding) == tensorflow::Status::OK()) {
       if (embedding_matrix == nullptr) {
-        const int embedding_size = embedding.vector().values_size();
-        OP_REQUIRES_OK(
-            context,
-            context->allocate_output(
-                0, TensorShape({word_map->Size() + 3, embedding_size}),
-                &embedding_matrix));
-        auto matrix = embedding_matrix->matrix<float>();
-        Eigen::internal::NormalRandomGenerator<float> prng(seed_);
-        matrix =
-            matrix.random(prng) * (embedding_init_ / sqrtf(embedding_size));
+        OP_REQUIRES_OK(context,
+                       InitRandomEmbeddingMatrix(vocab, embedding, context,
+                                                 &embedding_matrix));
       }
       if (vocab.find(embedding.token()) != vocab.end()) {
         SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
                          embedding_matrix);
       }
     }
+
+    // The vectors file might not contain any embeddings (perhaps due to read
+    // errors), in which case the |embedding_matrix| output is never allocated.
+    // Signal this error early instead of letting downstream ops complain about
+    // a missing input.
+    OP_REQUIRES(
+        context, embedding_matrix != nullptr,
+        InvalidArgument(tensorflow::strings::StrCat(
+            "found no pretrained embeddings in vectors=", vectors_path_,
+            " vocabulary=", vocabulary_path_, " vocab_size=", vocab.size())));
   }

  private:
+  // Loads the vocabulary from the task context or vocabulary.
+  tensorflow::Status LoadVocabulary(
+      std::unordered_map<string, int64> *vocabulary) const {
+    if (!task_context_path_.empty()) {
+      return LoadVocabularyFromTaskContext(vocabulary);
+    } else {
+      return LoadVocabularyFromFile(vocabulary);
+    }
+  }
+
+  // Loads the |vocabulary| from the "word-map" input of the task context at
+  // |task_context_path_|, or returns non-OK on error.
+  tensorflow::Status LoadVocabularyFromTaskContext(
+      std::unordered_map<string, int64> *vocabulary) const {
+    vocabulary->clear();
+    string textproto;
+    TF_RETURN_IF_ERROR(ReadFileToString(tensorflow::Env::Default(),
+                                        task_context_path_, &textproto));
+    TaskContext task_context;
+    if (!TextFormat::ParseFromString(textproto, task_context.mutable_spec())) {
+      return InvalidArgument("Could not parse task context at ",
+                             task_context_path_);
+    }
+    const string path =
+        TaskContext::InputFile(*task_context.GetInput("word-map"));
+    const TermFrequencyMap *word_map =
+        SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
+    for (int i = 0; i < word_map->Size(); ++i) {
+      (*vocabulary)[word_map->GetTerm(i)] = i;
+    }
+    return tensorflow::Status::OK();
+  }
+
+  // Loads the |vocabulary| from the |vocabulary_path_| file, which contains
+  // one unique word per line, in order, or returns non-OK on error.
+  tensorflow::Status LoadVocabularyFromFile(
+      std::unordered_map<string, int64> *vocabulary) const {
+    vocabulary->clear();
+    string text;
+    TF_RETURN_IF_ERROR(ReadFileToString(tensorflow::Env::Default(),
+                                        vocabulary_path_, &text));
+
+    // Chomp a trailing newline, if any, to avoid producing a spurious empty
+    // term at the end of the vocabulary file.
+    if (!text.empty() && text.back() == '\n') text.pop_back();
+    for (const string &line : tensorflow::str_util::Split(text, "\n")) {
+      if (vocabulary->find(line) != vocabulary->end()) {
+        return InvalidArgument("Vocabulary file at ", vocabulary_path_,
+                               " contains multiple instances of term: ", line);
+      }
+      const int64 index = vocabulary->size();
+      (*vocabulary)[line] = index;
+    }
+    return tensorflow::Status::OK();
+  }
+
+  // Allocates the |embedding_matrix| based on the |vocabulary| and |embedding|
+  // and initializes it to random values, or returns non-OK on error.
+  tensorflow::Status InitRandomEmbeddingMatrix(
+      const std::unordered_map<string, int64> &vocabulary,
+      const TokenEmbedding &embedding, OpKernelContext *context,
+      Tensor **embedding_matrix) const {
+    const int rows = vocabulary.size() + num_special_embeddings_;
+    const int columns = embedding.vector().values_size();
+    TF_RETURN_IF_ERROR(context->allocate_output(
+        0, TensorShape({rows, columns}), embedding_matrix));
+    auto matrix = (*embedding_matrix)->matrix<float>();
+    Eigen::internal::NormalRandomGenerator<float> prng(seed_);
+    matrix = matrix.random(prng) * (embedding_init_ / sqrtf(columns));
+    return tensorflow::Status::OK();
+  }
+
   // Sets embedding_matrix[row] to a normalized version of the given vector.
   void SetNormalizedRow(const TokenEmbedding::Vector &vector, const int row,
                         Tensor *embedding_matrix) {
@@ -547,8 +627,15 @@ class WordEmbeddingInitializer : public OpKernel {
     }
   }

-  // Task context used to configure this op.
-  TaskContext task_context_;
+  // Path to the task context or vocabulary. Exactly one must be specified.
+  string task_context_path_;
+  string vocabulary_path_;
+
+  // Whether to cache the vectors to a local temp file, to reduce I/O latency.
+  bool cache_vectors_locally_ = true;
+
+  // Number of special embeddings to allocate.
+  int num_special_embeddings_ = 3;

   // Seed for random initialization.
   uint64 seed_ = 0;
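For readers skimming the C++, here is a line-by-line Python transcription of LoadVocabularyFromFile above (a sketch for illustration, not code from this commit): one unique word per line, a single trailing newline is chomped, and any duplicate term is an error.

def load_vocabulary_from_file(path):
  with open(path) as f:
    text = f.read()
  if text.endswith('\n'):
    text = text[:-1]  # avoid a spurious empty term at end of file
  vocabulary = {}
  for line in text.split('\n'):
    if line in vocabulary:
      raise ValueError('Vocabulary file at %s contains multiple instances '
                       'of term: %s' % (path, line))
    vocabulary[line] = len(vocabulary)  # ids assigned in file order
  return vocabulary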
syntaxnet/syntaxnet/reader_ops_test.py

@@ -15,6 +15,7 @@
 """Tests for reader_ops."""

+# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member,no-member,g-importing-member
 import os.path

 import numpy as np
@@ -29,7 +30,6 @@ from syntaxnet import graph_builder
 from syntaxnet import sparse_pb2
 from syntaxnet.ops import gen_parser_ops

 FLAGS = tf.app.flags.FLAGS
 if not hasattr(FLAGS, 'test_srcdir'):
   FLAGS.test_srcdir = ''
@@ -220,6 +220,76 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
       self.assertEqual(tf.shape(embeddings2)[1].eval(), 3)
       self.assertAllEqual(embeddings1.eval(), embeddings2.eval())

+  def testWordEmbeddingInitializerFailIfNeitherTaskContextOrVocabulary(self):
+    with self.test_session():
+      with self.assertRaises(Exception):
+        gen_parser_ops.word_embedding_initializer(vectors='/dev/null').eval()
+
+  def testWordEmbeddingInitializerFailIfBothTaskContextAndVocabulary(self):
+    with self.test_session():
+      with self.assertRaises(Exception):
+        gen_parser_ops.word_embedding_initializer(
+            vectors='/dev/null',
+            task_context='/dev/null',
+            vocabulary='/dev/null').eval()
+
+  def testWordEmbeddingInitializerVocabularyFile(self):
+    records_path = os.path.join(FLAGS.test_tmpdir, 'records3')
+    writer = tf.python_io.TFRecordWriter(records_path)
+    writer.write(self._token_embedding('a', [1, 2, 3]))
+    writer.write(self._token_embedding('b', [2, 3, 4]))
+    writer.write(self._token_embedding('c', [3, 4, 5]))
+    writer.write(self._token_embedding('d', [4, 5, 6]))
+    writer.write(self._token_embedding('e', [5, 6, 7]))
+    del writer
+
+    vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary3')
+    with open(vocabulary_path, 'w') as vocabulary_file:
+      vocabulary_file.write('a\nc\ne\nx\n')  # 'x' not in pretrained embeddings
+
+    # Enumerate a variety of configurations.
+    for cache_vectors_locally in [False, True]:
+      for num_special_embeddings in [None, 1, 2, 5]:  # None = use default of 3
+        with self.test_session():
+          embeddings = gen_parser_ops.word_embedding_initializer(
+              vectors=records_path,
+              vocabulary=vocabulary_path,
+              cache_vectors_locally=cache_vectors_locally,
+              num_special_embeddings=num_special_embeddings)
+
+          # Expect 4 embeddings from the vocabulary plus special embeddings.
+          expected_num_embeddings = 4 + (num_special_embeddings or 3)
+          self.assertAllEqual([expected_num_embeddings, 3],
+                              tf.shape(embeddings).eval())
+
+          # The first 3 embeddings should be pretrained.
+          norm_a = (1.0 + 4.0 + 9.0) ** 0.5
+          norm_c = (9.0 + 16.0 + 25.0) ** 0.5
+          norm_e = (25.0 + 36.0 + 49.0) ** 0.5
+          self.assertAllClose([[1.0 / norm_a, 2.0 / norm_a, 3.0 / norm_a],
+                               [3.0 / norm_c, 4.0 / norm_c, 5.0 / norm_c],
+                               [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
+                              embeddings[:3].eval())
+
+  def testWordEmbeddingInitializerVocabularyFileWithDuplicates(self):
+    records_path = os.path.join(FLAGS.test_tmpdir, 'records4')
+    writer = tf.python_io.TFRecordWriter(records_path)
+    writer.write(self._token_embedding('a', [1, 2, 3]))
+    writer.write(self._token_embedding('b', [2, 3, 4]))
+    writer.write(self._token_embedding('c', [3, 4, 5]))
+    writer.write(self._token_embedding('d', [4, 5, 6]))
+    writer.write(self._token_embedding('e', [5, 6, 7]))
+    del writer
+
+    vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary4')
+    with open(vocabulary_path, 'w') as vocabulary_file:
+      vocabulary_file.write('a\nc\ne\nx\ny\nx')  # 'x' duplicated
+
+    with self.test_session():
+      with self.assertRaises(Exception):
+        gen_parser_ops.word_embedding_initializer(
+            vectors=records_path, vocabulary=vocabulary_path).eval()
+
 if __name__ == '__main__':
   googletest.main()
syntaxnet/syntaxnet/syntaxnet_ops.py  (new file, 0 → 100644)

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Imports the SyntaxNet ops and their C++ implementations."""

from syntaxnet.ops.gen_parser_ops import *  # pylint: disable=wildcard-import

import syntaxnet.load_parser_ops
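With this module and the new //syntaxnet:syntaxnet_ops BUILD target in place, DRAGNN code can call the op the way the network_units.py hunk above now does. A minimal sketch with placeholder paths:

from syntaxnet import syntaxnet_ops

# word_embedding_initializer is re-exported from gen_parser_ops by the
# wildcard import above; the paths here are hypothetical.
embeddings = syntaxnet_ops.word_embedding_initializer(
    vectors='/path/to/vectors',
    vocabulary='/path/to/vocabulary',
    num_special_embeddings=1,
    embedding_init=1.0)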