ModelZoo / ResNet50_tensorflow · Commits

Commit 32ab5a58, authored May 12, 2016 by calberti, committed by Martin Wicke, May 12, 2016

    Adding SyntaxNet to tensorflow/models (#63)

Parent: 148a15fb

Changes: 131 · Showing 20 changed files with 3599 additions and 0 deletions (+3599, −0)
syntaxnet/syntaxnet/proto_io.h                     +242  −0
syntaxnet/syntaxnet/reader_ops.cc                  +563  −0
syntaxnet/syntaxnet/reader_ops_test.py             +198  −0
syntaxnet/syntaxnet/registry.cc                    +28   −0
syntaxnet/syntaxnet/registry.h                     +243  −0
syntaxnet/syntaxnet/sentence.proto                 +61   −0
syntaxnet/syntaxnet/sentence_batch.cc              +45   −0
syntaxnet/syntaxnet/sentence_batch.h               +78   −0
syntaxnet/syntaxnet/sentence_features.cc           +192  −0
syntaxnet/syntaxnet/sentence_features.h            +317  −0
syntaxnet/syntaxnet/sentence_features_test.cc      +155  −0
syntaxnet/syntaxnet/shared_store.cc                +91   −0
syntaxnet/syntaxnet/shared_store.h                 +234  −0
syntaxnet/syntaxnet/shared_store_test.cc           +242  −0
syntaxnet/syntaxnet/sparse.proto                   +19   −0
syntaxnet/syntaxnet/structured_graph_builder.py    +240  −0
syntaxnet/syntaxnet/syntaxnet.bzl                  +107  −0
syntaxnet/syntaxnet/tagger_transitions.cc          +258  −0
syntaxnet/syntaxnet/tagger_transitions_test.cc     +113  −0
syntaxnet/syntaxnet/task_context.cc                +173  −0
syntaxnet/syntaxnet/proto_io.h  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_PROTO_IO_H_
#define $TARGETDIR_PROTO_IO_H_
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/document_format.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/registry.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {

// A convenience wrapper to read protos with a RecordReader.
class ProtoRecordReader {
 public:
  explicit ProtoRecordReader(tensorflow::RandomAccessFile *file)
      : file_(file), reader_(new tensorflow::io::RecordReader(file_)) {}

  explicit ProtoRecordReader(const string &filename) {
    TF_CHECK_OK(
        tensorflow::Env::Default()->NewRandomAccessFile(filename, &file_));
    reader_.reset(new tensorflow::io::RecordReader(file_));
  }

  ~ProtoRecordReader() {
    reader_.reset();
    delete file_;
  }

  template <typename T>
  tensorflow::Status Read(T *proto) {
    string buffer;
    tensorflow::Status status = reader_->ReadRecord(&offset_, &buffer);
    if (status.ok()) {
      CHECK(proto->ParseFromString(buffer));
      return tensorflow::Status::OK();
    } else {
      return status;
    }
  }

 private:
  tensorflow::RandomAccessFile *file_ = nullptr;
  uint64 offset_ = 0;
  std::unique_ptr<tensorflow::io::RecordReader> reader_;
};
// A convenience wrapper to write protos with a RecordWriter.
class ProtoRecordWriter {
 public:
  explicit ProtoRecordWriter(const string &filename) {
    TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file_));
    writer_.reset(new tensorflow::io::RecordWriter(file_));
  }

  ~ProtoRecordWriter() {
    writer_.reset();
    delete file_;
  }

  template <typename T>
  void Write(const T &proto) {
    TF_CHECK_OK(writer_->WriteRecord(proto.SerializeAsString()));
  }

 private:
  tensorflow::WritableFile *file_ = nullptr;
  std::unique_ptr<tensorflow::io::RecordWriter> writer_;
};
// A file implementation to read from stdin.
class StdIn : public tensorflow::RandomAccessFile {
 public:
  StdIn() {}
  ~StdIn() override {}

  // Reads up to n bytes from standard input. Returns `OUT_OF_RANGE` if fewer
  // than n bytes were stored in `*result` because of EOF.
  tensorflow::Status Read(uint64 offset, size_t n,
                          tensorflow::StringPiece *result,
                          char *scratch) const override {
    CHECK_EQ(expected_offset_, offset);
    if (!eof_) {
      string line;
      eof_ = !std::getline(std::cin, line);
      buffer_.append(line);
      buffer_.append("\n");
    }
    CopyFromBuffer(std::min(buffer_.size(), n), result, scratch);
    if (eof_) {
      return tensorflow::errors::OutOfRange("End of file reached");
    } else {
      return tensorflow::Status::OK();
    }
  }

 private:
  void CopyFromBuffer(size_t n, tensorflow::StringPiece *result,
                      char *scratch) const {
    memcpy(scratch, buffer_.data(), buffer_.size());
    buffer_ = buffer_.substr(n);
    result->set(scratch, n);
    expected_offset_ += n;
  }

  mutable bool eof_ = false;
  mutable int64 expected_offset_ = 0;
  mutable string buffer_;

  TF_DISALLOW_COPY_AND_ASSIGN(StdIn);
};
// Reads sentence protos from a text file.
class TextReader {
 public:
  explicit TextReader(const TaskInput &input) {
    CHECK_EQ(input.record_format_size(), 1)
        << "TextReader only supports inputs with one record format: "
        << input.DebugString();
    CHECK_EQ(input.part_size(), 1)
        << "TextReader only supports inputs with one part: "
        << input.DebugString();
    filename_ = TaskContext::InputFile(input);
    format_.reset(DocumentFormat::Create(input.record_format(0)));
    Reset();
  }

  Sentence *Read() {
    // Skips empty sentences, e.g., blank lines at the beginning of a file or
    // commented-out blocks.
    vector<Sentence *> sentences;
    string key, value;
    while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
      key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
      format_->ConvertFromString(key, value, &sentences);
      CHECK_LE(sentences.size(), 1);
    }
    if (sentences.empty()) {
      // End of file reached.
      return nullptr;
    } else {
      ++sentence_count_;
      return sentences[0];
    }
  }

  void Reset() {
    sentence_count_ = 0;
    tensorflow::RandomAccessFile *file;
    if (filename_ == "-") {
      static const int kInputBufferSize = 8 * 1024; /* bytes */
      file = new StdIn();
      buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
    } else {
      static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
      TF_CHECK_OK(
          tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file));
      buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
    }
  }

 private:
  string filename_;
  int sentence_count_ = 0;
  std::unique_ptr<tensorflow::io::InputBuffer> buffer_;
  std::unique_ptr<DocumentFormat> format_;
};
// Writes sentence protos to a text CoNLL file.
class TextWriter {
 public:
  explicit TextWriter(const TaskInput &input) {
    CHECK_EQ(input.record_format_size(), 1)
        << "TextWriter only supports files with one record format: "
        << input.DebugString();
    CHECK_EQ(input.part_size(), 1)
        << "TextWriter only supports files with one part: "
        << input.DebugString();
    filename_ = TaskContext::InputFile(input);
    format_.reset(DocumentFormat::Create(input.record_format(0)));
    if (filename_ != "-") {
      TF_CHECK_OK(
          tensorflow::Env::Default()->NewWritableFile(filename_, &file_));
    }
  }

  ~TextWriter() {
    if (file_) {
      file_->Close();
      delete file_;
    }
  }

  void Write(const Sentence &sentence) {
    string key, value;
    format_->ConvertToString(sentence, &key, &value);
    if (file_) {
      TF_CHECK_OK(file_->Append(value));
    } else {
      std::cout << value;
    }
  }

 private:
  string filename_;
  std::unique_ptr<DocumentFormat> format_;
  tensorflow::WritableFile *file_ = nullptr;
};

}  // namespace syntaxnet

#endif  // $TARGETDIR_PROTO_IO_H_
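A sketch of how TextReader and TextWriter pair up, assuming an initialized TaskContext; "training-corpus" appears in the tests below, while "tagged-output" is a hypothetical input name used only for illustration:

    // Hedged sketch: copy a corpus through the reader/writer pair.
    // Assumes `context` points to a configured syntaxnet::TaskContext.
    syntaxnet::TextReader reader(*context->GetInput("training-corpus"));
    syntaxnet::TextWriter writer(*context->GetInput("tagged-output"));
    while (syntaxnet::Sentence *sentence = reader.Read()) {
      writer.Write(*sentence);
      delete sentence;  // Read() transfers ownership to the caller.
    }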
syntaxnet/syntaxnet/reader_ops.cc  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <deque>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence_batch.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
using tensorflow::DEVICE_CPU;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::error::OUT_OF_RANGE;
using tensorflow::errors::InvalidArgument;
namespace syntaxnet {

class ParsingReader : public OpKernel {
 public:
  explicit ParsingReader(OpKernelConstruction *context) : OpKernel(context) {
    string file_path, corpus_name;
    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
    OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
    OP_REQUIRES_OK(context, context->GetAttr("batch_size", &max_batch_size_));
    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
    OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));

    // Reads task context from file.
    string data;
    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
                                             file_path, &data));
    OP_REQUIRES(context,
                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
                InvalidArgument("Could not parse task context at ", file_path));

    // Set up the batch reader.
    sentence_batch_.reset(new SentenceBatch(max_batch_size_, corpus_name));
    sentence_batch_->Init(&task_context_);

    // Set up the parsing features and transition system.
    states_.resize(max_batch_size_);
    workspaces_.resize(max_batch_size_);
    features_.reset(new ParserEmbeddingFeatureExtractor(arg_prefix_));
    features_->Setup(&task_context_);
    transition_system_.reset(ParserTransitionSystem::Create(task_context_.Get(
        features_->GetParamName("transition_system"), "arc-standard")));
    transition_system_->Setup(&task_context_);
    features_->Init(&task_context_);
    features_->RequestWorkspaces(&workspace_registry_);
    transition_system_->Init(&task_context_);
    string label_map_path =
        TaskContext::InputFile(*task_context_.GetInput("label-map"));
    label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
        label_map_path, 0, 0);

    // Checks that the number of feature groups matches the task context.
    const int required_size = features_->embedding_dims().size();
    OP_REQUIRES(
        context, feature_size_ == required_size,
        InvalidArgument("Task context requires feature_size=", required_size));
  }

  ~ParsingReader() override { SharedStore::Release(label_map_); }

  // Creates a new ParserState if there's another sentence to be read.
  virtual void AdvanceSentence(int index) {
    states_[index].reset();
    if (sentence_batch_->AdvanceSentence(index)) {
      states_[index].reset(new ParserState(
          sentence_batch_->sentence(index),
          transition_system_->NewTransitionState(true), label_map_));
      workspaces_[index].Reset(workspace_registry_);
      features_->Preprocess(&workspaces_[index], states_[index].get());
    }
  }

  void Compute(OpKernelContext *context) override {
    mutex_lock lock(mu_);

    // Advances states to the next positions.
    PerformActions(context);

    // Advances any final states to the next sentences.
    for (int i = 0; i < max_batch_size_; ++i) {
      if (state(i) == nullptr) continue;

      // Switches to the next sentence if we're at a final state.
      while (transition_system_->IsFinalState(*state(i))) {
        VLOG(2) << "Advancing sentence " << i;
        AdvanceSentence(i);
        if (state(i) == nullptr) break;  // EOF has been reached.
      }
    }

    // Rewinds the corpus if no states remain in the batch.
    if (sentence_batch_->size() == 0) {
      ++num_epochs_;
      LOG(INFO) << "Starting epoch " << num_epochs_;
      sentence_batch_->Rewind();
      for (int i = 0; i < max_batch_size_; ++i) AdvanceSentence(i);
    }

    // Create the outputs for each feature space.
    vector<Tensor *> feature_outputs(features_->NumEmbeddings());
    for (size_t i = 0; i < feature_outputs.size(); ++i) {
      OP_REQUIRES_OK(context, context->allocate_output(
                                  i, TensorShape({sentence_batch_->size(),
                                                  features_->FeatureSize(i)}),
                                  &feature_outputs[i]));
    }

    // Populate feature outputs.
    for (int i = 0, index = 0; i < max_batch_size_; ++i) {
      if (states_[i] == nullptr) continue;

      // Extract features from the current parser state, and fill up the
      // available batch slots.
      std::vector<std::vector<SparseFeatures>> features =
          features_->ExtractSparseFeatures(workspaces_[i], *states_[i]);
      for (size_t feature_space = 0; feature_space < features.size();
           ++feature_space) {
        int feature_size = features[feature_space].size();
        CHECK(feature_size == features_->FeatureSize(feature_space));
        auto features_output = feature_outputs[feature_space]->matrix<string>();
        for (int k = 0; k < feature_size; ++k) {
          features_output(index, k) =
              features[feature_space][k].SerializeAsString();
        }
      }
      ++index;
    }

    // Return the number of epochs.
    Tensor *epoch_output;
    OP_REQUIRES_OK(context, context->allocate_output(
                                feature_size_, TensorShape({}), &epoch_output));
    auto num_epochs = epoch_output->scalar<int32>();
    num_epochs() = num_epochs_;

    // Create outputs specific to this reader.
    AddAdditionalOutputs(context);
  }

 protected:
  // Performs any relevant actions on the parser states, typically either
  // the gold action or a predicted action from decoding.
  virtual void PerformActions(OpKernelContext *context) = 0;

  // Adds outputs specific to this reader, starting at
  // additional_output_index().
  virtual void AddAdditionalOutputs(OpKernelContext *context) const = 0;

  // Returns the output type specification of this base class.
  std::vector<DataType> default_outputs() const {
    std::vector<DataType> output_types(feature_size_, DT_STRING);
    output_types.push_back(DT_INT32);
    return output_types;
  }

  // Accessors.
  int max_batch_size() const { return max_batch_size_; }
  int batch_size() const { return sentence_batch_->size(); }
  int additional_output_index() const { return feature_size_ + 1; }
  ParserState *state(int i) const { return states_[i].get(); }
  const ParserTransitionSystem &transition_system() const {
    return *transition_system_.get();
  }

  // Parser task context.
  const TaskContext &task_context() const { return task_context_; }

  const string &arg_prefix() const { return arg_prefix_; }

 private:
  // Task context used to configure this op.
  TaskContext task_context_;

  // Prefix for context parameters.
  string arg_prefix_;

  // Mutex to synchronize access to Compute.
  mutex mu_;

  // How many times the document source has been rewound.
  int num_epochs_ = 0;

  // How many sentences this op can be processing at any given time.
  int max_batch_size_ = 1;

  // Number of feature groups in the brain parser features.
  int feature_size_ = -1;

  // Batch of sentences, and the corresponding parser states.
  std::unique_ptr<SentenceBatch> sentence_batch_;

  // Batch: ParserState objects.
  std::vector<std::unique_ptr<ParserState>> states_;

  // Batch: WorkspaceSet objects.
  std::vector<WorkspaceSet> workspaces_;

  // Dependency label map used in transition system.
  const TermFrequencyMap *label_map_;

  // Transition system.
  std::unique_ptr<ParserTransitionSystem> transition_system_;

  // Typed feature extractor for embeddings.
  std::unique_ptr<ParserEmbeddingFeatureExtractor> features_;

  // Internal workspace registry for use in feature extraction.
  WorkspaceRegistry workspace_registry_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParsingReader);
};
class GoldParseReader : public ParsingReader {
 public:
  explicit GoldParseReader(OpKernelConstruction *context)
      : ParsingReader(context) {
    // Sets up number and type of inputs and outputs.
    std::vector<DataType> output_types = default_outputs();
    output_types.push_back(DT_INT32);
    OP_REQUIRES_OK(context, context->MatchSignature({}, output_types));
  }

 private:
  // Always performs the next gold action for each state.
  void PerformActions(OpKernelContext *context) override {
    for (int i = 0; i < max_batch_size(); ++i) {
      if (state(i) != nullptr) {
        transition_system().PerformAction(
            transition_system().GetNextGoldAction(*state(i)), state(i));
      }
    }
  }

  // Adds the list of gold actions for each state as an additional output.
  void AddAdditionalOutputs(OpKernelContext *context) const override {
    Tensor *actions_output;
    OP_REQUIRES_OK(context, context->allocate_output(
                                additional_output_index(),
                                TensorShape({batch_size()}), &actions_output));

    // Add all gold actions for non-null states as an additional output.
    auto gold_actions = actions_output->vec<int32>();
    for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
      if (state(i) != nullptr) {
        const int gold_action =
            transition_system().GetNextGoldAction(*state(i));
        gold_actions(batch_index++) = gold_action;
      }
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(GoldParseReader);
};

REGISTER_KERNEL_BUILDER(Name("GoldParseReader").Device(DEVICE_CPU),
                        GoldParseReader);
// DecodedParseReader parses sentences using transition scores computed
// by a TensorFlow network. This op additionally computes a token correctness
// evaluation metric which can be used to select hyperparameter settings and
// training stopping point.
//
// The notion of correct token is determined by the transition system, e.g.
// a tagger will return POS tag accuracy, while an arc-standard parser will
// return UAS.
//
// Which tokens should be scored is controlled by the '<arg_prefix>_scoring'
// task parameter. Possible values are
// - 'default': skips tokens with only punctuation in the tag name.
// - 'conllx': skips tokens with only punctuation in the surface form.
// - 'ignore_parens': same as conllx, but skipping parentheses as well.
// - '': scores all tokens.
class DecodedParseReader : public ParsingReader {
 public:
  explicit DecodedParseReader(OpKernelConstruction *context)
      : ParsingReader(context) {
    // Sets up number and type of inputs and outputs.
    std::vector<DataType> output_types = default_outputs();
    output_types.push_back(DT_INT32);
    output_types.push_back(DT_STRING);
    OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT}, output_types));

    // Gets scoring parameters.
    scoring_type_ = task_context().Get(
        tensorflow::strings::StrCat(arg_prefix(), "_scoring"), "");
  }

 private:
  void AdvanceSentence(int index) override {
    ParsingReader::AdvanceSentence(index);
    if (state(index)) {
      docids_.push_front(state(index)->sentence().docid());
    }
  }

  // Tallies the # of correct and incorrect tokens for a given ParserState.
  void ComputeTokenAccuracy(const ParserState &state) {
    for (int i = 0; i < state.sentence().token_size(); ++i) {
      const Token &token = state.GetToken(i);
      if (utils::PunctuationUtil::ScoreToken(token.word(), token.tag(),
                                             scoring_type_)) {
        ++num_tokens_;
        if (state.IsTokenCorrect(i)) ++num_correct_;
      }
    }
  }

  // Performs the allowed action with the highest score on the given state.
  // Also records the accuracy whenever a terminal action is taken.
  void PerformActions(OpKernelContext *context) override {
    auto scores_matrix = context->input(0).matrix<float>();
    num_tokens_ = 0;
    num_correct_ = 0;
    for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
      ParserState *state = this->state(i);
      if (state != nullptr) {
        int best_action = 0;
        float best_score = -INFINITY;
        for (int action = 0; action < scores_matrix.dimension(1); ++action) {
          float score = scores_matrix(batch_index, action);
          if (score > best_score &&
              transition_system().IsAllowedAction(action, *state)) {
            best_action = action;
            best_score = score;
          }
        }
        transition_system().PerformAction(best_action, state);

        // Update the # of scored correct tokens if this is the last state
        // in the sentence and save the annotated document.
        if (transition_system().IsFinalState(*state)) {
          ComputeTokenAccuracy(*state);
          sentence_map_[state->sentence().docid()] = state->sentence();
          state->AddParseToDocument(&sentence_map_[state->sentence().docid()]);
        }
        ++batch_index;
      }
    }
  }

  // Adds the evaluation metrics and annotated documents as additional outputs,
  // if there were any terminal states.
  void AddAdditionalOutputs(OpKernelContext *context) const override {
    Tensor *counts_output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(additional_output_index(),
                                            TensorShape({2}), &counts_output));
    auto eval_metrics = counts_output->vec<int32>();
    eval_metrics(0) = num_tokens_;
    eval_metrics(1) = num_correct_;

    // Output annotated documents for each state. To preserve order, repeatedly
    // pull from the back of the docids queue as long as the sentences have
    // been completely processed. If the next document has not been completely
    // processed yet, then the docid will not be found in 'sentence_map_'.
    vector<Sentence> sentences;
    while (!docids_.empty() &&
           sentence_map_.find(docids_.back()) != sentence_map_.end()) {
      sentences.emplace_back(sentence_map_[docids_.back()]);
      sentence_map_.erase(docids_.back());
      docids_.pop_back();
    }
    Tensor *annotated_output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       additional_output_index() + 1,
                       TensorShape({static_cast<int64>(sentences.size())}),
                       &annotated_output));
    auto document_output = annotated_output->vec<string>();
    for (size_t i = 0; i < sentences.size(); ++i) {
      document_output(i) = sentences[i].SerializeAsString();
    }
  }

  // State for eval metric computation.
  int num_tokens_ = 0;
  int num_correct_ = 0;

  // Parameter for deciding which tokens to score.
  string scoring_type_;

  mutable std::deque<string> docids_;
  mutable map<string, Sentence> sentence_map_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
};

REGISTER_KERNEL_BUILDER(Name("DecodedParseReader").Device(DEVICE_CPU),
                        DecodedParseReader);
class WordEmbeddingInitializer : public OpKernel {
 public:
  explicit WordEmbeddingInitializer(OpKernelConstruction *context)
      : OpKernel(context) {
    string file_path, data;
    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
                                             file_path, &data));
    OP_REQUIRES(context,
                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
                InvalidArgument("Could not parse task context at ", file_path));
    OP_REQUIRES_OK(context, context->GetAttr("vectors", &vectors_path_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("embedding_init", &embedding_init_));

    // Sets up number and type of inputs and outputs.
    OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext *context) override {
    // Loads words from vocabulary with mapping to ids.
    string path = TaskContext::InputFile(*task_context_.GetInput("word-map"));
    const TermFrequencyMap *word_map =
        SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
    unordered_map<string, int64> vocab;
    for (int i = 0; i < word_map->Size(); ++i) {
      vocab[word_map->GetTerm(i)] = i;
    }

    // Creates a reader pointing to a local copy of the vectors recordio.
    string tmp_vectors_path;
    OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &tmp_vectors_path));
    ProtoRecordReader reader(tmp_vectors_path);

    // Loads the embedding vectors into a matrix.
    Tensor *embedding_matrix = nullptr;
    TokenEmbedding embedding;
    while (reader.Read(&embedding) == tensorflow::Status::OK()) {
      if (embedding_matrix == nullptr) {
        const int embedding_size = embedding.vector().values_size();
        OP_REQUIRES_OK(
            context,
            context->allocate_output(
                0, TensorShape({word_map->Size() + 3, embedding_size}),
                &embedding_matrix));
        embedding_matrix->matrix<float>()
            .setRandom<Eigen::internal::NormalRandomGenerator<float>>();
        embedding_matrix->matrix<float>() =
            embedding_matrix->matrix<float>() *
            static_cast<float>(embedding_init_ / sqrt(embedding_size));
      }
      if (vocab.find(embedding.token()) != vocab.end()) {
        SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
                         embedding_matrix);
      }
    }
  }

 private:
  // Sets embedding_matrix[row] to a normalized version of the given vector.
  void SetNormalizedRow(const TokenEmbedding::Vector &vector, const int row,
                        Tensor *embedding_matrix) {
    float norm = 0.0f;
    for (int col = 0; col < vector.values_size(); ++col) {
      float val = vector.values(col);
      norm += val * val;
    }
    norm = sqrt(norm);
    for (int col = 0; col < vector.values_size(); ++col) {
      embedding_matrix->matrix<float>()(row, col) = vector.values(col) / norm;
    }
  }

  // Copies the file at source_path to a temporary file and sets tmp_path to
  // the temporary file's location. This is helpful since reading from non-local
  // files with a record reader can be very slow.
  static tensorflow::Status CopyToTmpPath(const string &source_path,
                                          string *tmp_path) {
    // Opens source file.
    tensorflow::RandomAccessFile *source_file;
    TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewRandomAccessFile(
        source_path, &source_file));
    std::unique_ptr<tensorflow::RandomAccessFile> source_file_deleter(
        source_file);

    // Creates destination file.
    tensorflow::WritableFile *target_file;
    *tmp_path = tensorflow::strings::Printf(
        "/tmp/%d.%lld", getpid(), tensorflow::Env::Default()->NowMicros());
    TF_RETURN_IF_ERROR(
        tensorflow::Env::Default()->NewWritableFile(*tmp_path, &target_file));
    std::unique_ptr<tensorflow::WritableFile> target_file_deleter(target_file);

    // Performs copy.
    tensorflow::Status s;
    const size_t kBytesToRead = 10 << 20;  // 10MB at a time.
    string scratch;
    scratch.resize(kBytesToRead);
    for (uint64 offset = 0; s.ok(); offset += kBytesToRead) {
      tensorflow::StringPiece data;
      s.Update(source_file->Read(offset, kBytesToRead, &data, &scratch[0]));
      target_file->Append(data);
    }
    if (s.code() == OUT_OF_RANGE) {
      return tensorflow::Status::OK();
    } else {
      return s;
    }
  }

  // Task context used to configure this op.
  TaskContext task_context_;

  // Embedding vectors that are not found in the input sstable are initialized
  // randomly from a normal distribution with zero mean and
  // std dev = embedding_init_ / sqrt(embedding_size).
  float embedding_init_ = 1.f;

  // Path to recordio with word embedding vectors.
  string vectors_path_;
};

REGISTER_KERNEL_BUILDER(Name("WordEmbeddingInitializer").Device(DEVICE_CPU),
                        WordEmbeddingInitializer);

}  // namespace syntaxnet
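Rows for words found in the vectors file are overwritten with L2-normalized copies; the test below checks exactly this (e.g. [1, 2] becomes [1/sqrt(5), 2/sqrt(5)]). A self-contained sketch of that normalization arithmetic, as a hypothetical standalone helper rather than anything in the commit:

    #include <cmath>
    #include <vector>

    // Returns v / ||v||_2, mirroring SetNormalizedRow above.
    std::vector<float> L2Normalize(const std::vector<float> &v) {
      float norm = 0.0f;
      for (float x : v) norm += x * x;
      norm = std::sqrt(norm);
      std::vector<float> result;
      result.reserve(v.size());
      for (float x : v) result.push_back(x / norm);
      return result;
    }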
syntaxnet/syntaxnet/reader_ops_test.py  (new file, mode 100644)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reader_ops."""
import os.path

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging

from syntaxnet import dictionary_pb2
from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet.ops import gen_parser_ops

FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
  FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
  FLAGS.test_tmpdir = tf.test.get_temp_dir()


class ParsingReaderOpsTest(test_util.TensorFlowTestCase):

  def setUp(self):
    # Creates a task context with the correct testing paths.
    initial_task_context = os.path.join(FLAGS.test_srcdir,
                                        'syntaxnet/'
                                        'testdata/context.pbtxt')
    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
    with open(initial_task_context, 'r') as fin:
      with open(self._task_context, 'w') as fout:
        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
                   .replace('OUTPATH', FLAGS.test_tmpdir))

    # Creates necessary term maps.
    with self.test_session() as sess:
      gen_parser_ops.lexicon_builder(task_context=self._task_context,
                                     corpus_name='training-corpus').run()
      self._num_features, self._num_feature_ids, _, self._num_actions = (
          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
                                               arg_prefix='brain_parser')))

  def GetMaxId(self, sparse_features):
    max_id = 0
    for x in sparse_features:
      for y in x:
        f = sparse_pb2.SparseFeatures()
        f.ParseFromString(y)
        for i in f.id:
          max_id = max(i, max_id)
    return max_id

  def testParsingReaderOp(self):
    # Runs the reader over the test input for two epochs.
    num_steps_a = 0
    num_actions = 0
    num_word_ids = 0
    num_tag_ids = 0
    num_label_ids = 0
    batch_size = 10
    with self.test_session() as sess:
      (words, tags, labels), epochs, gold_actions = (
          gen_parser_ops.gold_parse_reader(self._task_context,
                                           3,
                                           batch_size,
                                           corpus_name='training-corpus'))
      while True:
        tf_gold_actions, tf_epochs, tf_words, tf_tags, tf_labels = (
            sess.run([gold_actions, epochs, words, tags, labels]))
        num_steps_a += 1
        num_actions = max(num_actions, max(tf_gold_actions) + 1)
        num_word_ids = max(num_word_ids, self.GetMaxId(tf_words) + 1)
        num_tag_ids = max(num_tag_ids, self.GetMaxId(tf_tags) + 1)
        num_label_ids = max(num_label_ids, self.GetMaxId(tf_labels) + 1)
        self.assertIn(tf_epochs, [0, 1, 2])
        if tf_epochs > 1:
          break

    # Runs the reader again, this time with a lot of added graph nodes.
    num_steps_b = 0
    with self.test_session() as sess:
      num_features = [6, 6, 4]
      num_feature_ids = [num_word_ids, num_tag_ids, num_label_ids]
      embedding_sizes = [8, 8, 8]
      hidden_layer_sizes = [32, 32]
      # Here we aim to test the iteration of the reader op in a complex
      # network, not the GraphBuilder.
      parser = graph_builder.GreedyParser(
          num_actions, num_features, num_feature_ids, embedding_sizes,
          hidden_layer_sizes)
      parser.AddTraining(self._task_context, batch_size,
                         corpus_name='training-corpus')
      sess.run(parser.inits.values())
      while True:
        tf_epochs, tf_cost, _ = sess.run(
            [parser.training['epochs'], parser.training['cost'],
             parser.training['train_op']])
        num_steps_b += 1
        self.assertGreaterEqual(tf_cost, 0)
        self.assertIn(tf_epochs, [0, 1, 2])
        if tf_epochs > 1:
          break

    # Assert that the two runs made the exact same number of steps.
    logging.info('Number of steps in the two runs: %d, %d',
                 num_steps_a, num_steps_b)
    self.assertEqual(num_steps_a, num_steps_b)

  def testParsingReaderOpWhileLoop(self):
    feature_size = 3
    batch_size = 5

    def ParserEndpoints():
      return gen_parser_ops.gold_parse_reader(self._task_context,
                                              feature_size,
                                              batch_size,
                                              corpus_name='training-corpus')

    with self.test_session() as sess:
      # The 'condition' and 'body' functions expect as many arguments as there
      # are loop variables. 'condition' depends on the 'epoch' loop variable
      # only, so we disregard the remaining unused function arguments. 'body'
      # returns a list of updated loop variables.
      def Condition(epoch, *unused_args):
        return tf.less(epoch, 2)

      def Body(epoch, num_actions, *feature_args):
        # By adding one of the outputs of the reader op ('epoch') as a control
        # dependency to the reader op we force the repeated evaluation of the
        # reader op.
        with epoch.graph.control_dependencies([epoch]):
          features, epoch, gold_actions = ParserEndpoints()
        num_actions = tf.maximum(num_actions,
                                 tf.reduce_max(gold_actions, [0], False) + 1)
        feature_ids = []
        for i in range(len(feature_args)):
          feature_ids.append(features[i])
        return [epoch, num_actions] + feature_ids

      epoch = ParserEndpoints()[-2]
      num_actions = tf.constant(0)
      loop_vars = [epoch, num_actions]

      res = sess.run(
          cf.While(Condition, Body, loop_vars, parallel_iterations=1))
      logging.info('Result: %s', res)
      self.assertEqual(res[0], 2)

  def testWordEmbeddingInitializer(self):
    def _TokenEmbedding(token, embedding):
      e = dictionary_pb2.TokenEmbedding()
      e.token = token
      e.vector.values.extend(embedding)
      return e.SerializeToString()

    # Provide embeddings for the first three words in the word map.
    records_path = os.path.join(FLAGS.test_tmpdir, 'sstable-00000-of-00001')
    writer = tf.python_io.TFRecordWriter(records_path)
    writer.write(_TokenEmbedding('.', [1, 2]))
    writer.write(_TokenEmbedding(',', [3, 4]))
    writer.write(_TokenEmbedding('the', [5, 6]))
    del writer

    with self.test_session():
      embeddings = gen_parser_ops.word_embedding_initializer(
          vectors=records_path,
          task_context=self._task_context).eval()
    self.assertAllClose(
        np.array([[1. / (1 + 4) ** .5, 2. / (1 + 4) ** .5],
                  [3. / (9 + 16) ** .5, 4. / (9 + 16) ** .5],
                  [5. / (25 + 36) ** .5, 6. / (25 + 36) ** .5]]),
        embeddings[:3,])


if __name__ == '__main__':
  googletest.main()
syntaxnet/syntaxnet/registry.cc  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/registry.h"
namespace syntaxnet {

// Global list of all component registries.
RegistryMetadata *global_registry_list = NULL;

void RegistryMetadata::Register(RegistryMetadata *registry) {
  registry->set_link(global_registry_list);
  global_registry_list = registry;
}

}  // namespace syntaxnet
syntaxnet/syntaxnet/registry.h  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registry for component registration. These classes can be used for creating
// registries of components conforming to the same interface. This is useful for
// making a component-based architecture where the specific implementation
// classes can be selected at runtime. There is support for both class-based and
// instance-based registries.
//
// Example:
// function.h:
//
// class Function : public RegisterableInstance<Function> {
// public:
// virtual double Evaluate(double x) = 0;
// };
//
// #define REGISTER_FUNCTION(type, component) \
// REGISTER_INSTANCE_COMPONENT(Function, type, component);
//
// function.cc:
//
// REGISTER_INSTANCE_REGISTRY("function", Function);
//
// class Cos : public Function {
// public:
// double Evaluate(double x) { return cos(x); }
// };
//
// class Exp : public Function {
// public:
// double Evaluate(double x) { return exp(x); }
// };
//
// REGISTER_FUNCTION("cos", Cos);
// REGISTER_FUNCTION("exp", Exp);
//
// Function *f = Function::Lookup("cos");
// double result = f->Evaluate(arg);
#ifndef $TARGETDIR_REGISTRY_H_
#define $TARGETDIR_REGISTRY_H_
#include <string.h>
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
namespace syntaxnet {

// Component metadata with information about name, class, and code location.
class ComponentMetadata {
 public:
  ComponentMetadata(const char *name, const char *class_name, const char *file,
                    int line)
      : name_(name),
        class_name_(class_name),
        file_(file),
        line_(line),
        link_(NULL) {}

  // Returns component name.
  const char *name() const { return name_; }

  // Metadata objects can be linked in a list.
  ComponentMetadata *link() const { return link_; }
  void set_link(ComponentMetadata *link) { link_ = link; }

 private:
  // Component name.
  const char *name_;

  // Name of class for component.
  const char *class_name_;

  // Code file and location where the component was registered.
  const char *file_;
  int line_;

  // Link to next metadata object in list.
  ComponentMetadata *link_;
};
// The master registry contains all registered component registries. A registry
// is not registered in the master registry until the first component of that
// type is registered.
class RegistryMetadata : public ComponentMetadata {
 public:
  RegistryMetadata(const char *name, const char *class_name, const char *file,
                   int line, ComponentMetadata **components)
      : ComponentMetadata(name, class_name, file, line),
        components_(components) {}

  // Registers a component registry in the master registry.
  static void Register(RegistryMetadata *registry);

 private:
  // Location of list of components in registry.
  ComponentMetadata **components_;
};
// Registry for components. An object can be registered with a type name in the
// registry. The named instances in the registry can be returned using the
// Lookup() method. The components in the registry are put into a linked list
// of components. It is important that the component registry can be statically
// initialized in order not to depend on initialization order.
template <class T>
struct ComponentRegistry {
  typedef ComponentRegistry<T> Self;

  // Component registration class.
  class Registrar : public ComponentMetadata {
   public:
    // Registers a new component by linking itself into the component list of
    // the registry.
    Registrar(Self *registry, const char *type, const char *class_name,
              const char *file, int line, T *object)
        : ComponentMetadata(type, class_name, file, line), object_(object) {
      // Register the registry in the master registry if this is the first
      // registered component of this type.
      if (registry->components == NULL) {
        RegistryMetadata::Register(new RegistryMetadata(
            registry->name, registry->class_name, registry->file,
            registry->line,
            reinterpret_cast<ComponentMetadata **>(&registry->components)));
      }

      // Register component in registry.
      set_link(registry->components);
      registry->components = this;
    }

    // Returns component type.
    const char *type() const { return name(); }

    // Returns component object.
    T *object() const { return object_; }

    // Returns the next component in the component list.
    Registrar *next() const { return static_cast<Registrar *>(link()); }

   private:
    // Component object.
    T *object_;
  };

  // Finds registrar for named component in registry.
  const Registrar *GetComponent(const char *type) const {
    Registrar *r = components;
    while (r != NULL && strcmp(type, r->type()) != 0) r = r->next();
    if (r == NULL) {
      LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
    }
    return r;
  }

  // Finds a named component in the registry.
  T *Lookup(const char *type) const { return GetComponent(type)->object(); }
  T *Lookup(const string &type) const { return Lookup(type.c_str()); }

  // Textual description of the kind of components in the registry.
  const char *name;

  // Base class name of component type.
  const char *class_name;

  // File and line where the registry is defined.
  const char *file;
  int line;

  // Linked list of registered components.
  Registrar *components;
};
// Base class for registerable class-based components.
template <class T>
class RegisterableClass {
 public:
  // Factory function type.
  typedef T *(Factory)();

  // Registry type.
  typedef ComponentRegistry<Factory> Registry;

  // Creates a new component instance.
  static T *Create(const string &type) { return registry()->Lookup(type)(); }

  // Returns registry for class.
  static Registry *registry() { return &registry_; }

 private:
  // Registry for class.
  static Registry registry_;
};

// Base class for registerable instance-based components.
template <class T>
class RegisterableInstance {
 public:
  // Registry type.
  typedef ComponentRegistry<T> Registry;

 private:
  // Registry for class.
  static Registry registry_;
};
#define REGISTER_CLASS_COMPONENT(base, type, component) \
static base *__##component##__factory() { return new component; } \
static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, \
__##component##__factory)
#define REGISTER_CLASS_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableClass<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
#define REGISTER_INSTANCE_COMPONENT(base, type, component) \
static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, new component)
#define REGISTER_INSTANCE_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableInstance<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
}  // namespace syntaxnet

#endif  // $TARGETDIR_REGISTRY_H_
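The comment at the top of this header shows the instance-based flavor; a matching class-based sketch using the macros above (the Function/Cos names follow that comment, and the whole block is illustrative rather than part of the commit):

    #include <cmath>

    namespace syntaxnet {

    // A class-based registerable component (hypothetical example).
    class Function : public RegisterableClass<Function> {
     public:
      virtual ~Function() {}
      virtual double Evaluate(double x) = 0;
    };

    // Instantiates the static registry for Function components.
    REGISTER_CLASS_REGISTRY("function", Function);

    class Cos : public Function {
     public:
      double Evaluate(double x) override { return std::cos(x); }
    };

    // Registers a factory that Create("cos") will invoke.
    REGISTER_CLASS_COMPONENT(Function, "cos", Cos);

    }  // namespace syntaxnet

    // Elsewhere: construct a new Cos through the registry by name.
    // syntaxnet::Function *f = syntaxnet::Function::Create("cos");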
syntaxnet/syntaxnet/sentence.proto  (new file, mode 100644)
// Protocol buffer specification for document analysis.

syntax = "proto2";

package syntaxnet;

// A Sentence contains the raw text contents of a sentence, as well as an
// analysis.
message Sentence {
  // Identifier for document.
  optional string docid = 1;

  // Raw text contents of the sentence.
  optional string text = 2;

  // Tokenization of the sentence.
  repeated Token token = 3;

  extensions 1000 to max;
}

// A document token marks a span of bytes in the document text as a token
// or word.
message Token {
  // Token word form.
  required string word = 1;

  // Start position of token in text.
  required int32 start = 2;

  // End position of token in text. Gives index of last byte, not one past
  // the last byte. If token came from lexer, excludes any trailing HTML tags.
  required int32 end = 3;

  // Head of this token in the dependency tree: the id of the token which has
  // an arc going to this one. If it is the root token of a sentence, then it
  // is set to -1.
  optional int32 head = 4 [default = -1];

  // Part-of-speech tag for token.
  optional string tag = 5;

  // Coarse-grained word category for token.
  optional string category = 6;

  // Label for dependency relation between this token and its head.
  optional string label = 7;

  // Break level for tokens that indicates how it was separated from the
  // previous token in the text.
  enum BreakLevel {
    NO_BREAK = 0;        // No separation between tokens.
    SPACE_BREAK = 1;     // Tokens separated by space.
    LINE_BREAK = 2;      // Tokens separated by line break.
    SENTENCE_BREAK = 3;  // Tokens separated by sentence break.
  }

  optional BreakLevel break_level = 8 [default = SPACE_BREAK];

  extensions 1000 to max;
}
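A small C++ sketch of populating this message, assuming the generated sentence.pb.h; note that `end` is the index of the last byte, per the Token comment above, and the docid is a hypothetical value:

    #include "syntaxnet/sentence.pb.h"

    // Builds "Hello world" with two tokens using the inclusive-end convention.
    syntaxnet::Sentence MakeSentence() {
      syntaxnet::Sentence sentence;
      sentence.set_docid("doc-1");  // hypothetical identifier
      sentence.set_text("Hello world");
      syntaxnet::Token *hello = sentence.add_token();
      hello->set_word("Hello");
      hello->set_start(0);
      hello->set_end(4);
      syntaxnet::Token *world = sentence.add_token();
      world->set_word("world");
      world->set_start(6);
      world->set_end(10);
      world->set_break_level(syntaxnet::Token::SPACE_BREAK);
      return sentence;
    }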
syntaxnet/syntaxnet/sentence_batch.cc  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_batch.h"
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/task_context.h"
namespace syntaxnet {

void SentenceBatch::Init(TaskContext *context) {
  reader_.reset(new TextReader(*context->GetInput(input_name_)));
  size_ = 0;
}

bool SentenceBatch::AdvanceSentence(int index) {
  if (sentences_[index] == nullptr) ++size_;
  sentences_[index].reset();
  std::unique_ptr<Sentence> sentence(reader_->Read());
  if (sentence == nullptr) {
    --size_;
    return false;
  }

  // Preprocess the new sentence for the parser state.
  sentences_[index] = std::move(sentence);
  return true;
}

}  // namespace syntaxnet
syntaxnet/syntaxnet/sentence_batch.h  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_SENTENCE_BATCH_H_
#define $TARGETDIR_SENTENCE_BATCH_H_
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/embedding_feature_extractor.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
namespace syntaxnet {

// Helper class to manage generating batches of preprocessed ParserState
// objects by reading in multiple sentences in parallel.
class SentenceBatch {
 public:
  SentenceBatch(int batch_size, string input_name)
      : batch_size_(batch_size),
        input_name_(input_name),
        sentences_(batch_size) {}

  // Initializes all resources and opens the corpus file.
  void Init(TaskContext *context);

  // Advances the index'th sentence in the batch to the next sentence. This
  // will create and preprocess a new ParserState for that element. Returns
  // false if EOF is reached (if EOF, also sets the state to be nullptr).
  bool AdvanceSentence(int index);

  // Rewinds the corpus reader.
  void Rewind() { reader_->Reset(); }

  int size() const { return size_; }

  Sentence *sentence(int index) { return sentences_[index].get(); }

 private:
  // Running tally of non-nullptr states in the batch.
  int size_;

  // Maximum number of states in the batch.
  int batch_size_;

  // Input to read from the TaskContext.
  string input_name_;

  // Reader for the corpus.
  std::unique_ptr<TextReader> reader_;

  // Batch: Sentence objects.
  std::vector<std::unique_ptr<Sentence>> sentences_;
};

}  // namespace syntaxnet

#endif  // $TARGETDIR_SENTENCE_BATCH_H_
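A sketch of the fill-then-refill pattern ParsingReader uses with this class, assuming an initialized TaskContext; the batch size and input name are illustrative:

    // Hedged sketch: keep a batch of 4 sentences flowing until EOF.
    // Assumes `context` is an initialized syntaxnet::TaskContext.
    syntaxnet::SentenceBatch batch(/*batch_size=*/4, "training-corpus");
    batch.Init(&context);
    for (int i = 0; i < 4; ++i) batch.AdvanceSentence(i);  // initial fill
    while (batch.size() > 0) {
      for (int i = 0; i < 4; ++i) {
        if (batch.sentence(i) == nullptr) continue;
        // ... consume batch.sentence(i) here ...
        batch.AdvanceSentence(i);  // returns false and empties the slot at EOF
      }
    }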
syntaxnet/syntaxnet/sentence_features.cc  (new file, mode 100644)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/registry.h"
#include "util/utf8/unicodetext.h"
namespace
syntaxnet
{
TermFrequencyMapFeature
::~
TermFrequencyMapFeature
()
{
if
(
term_map_
!=
nullptr
)
{
SharedStore
::
Release
(
term_map_
);
term_map_
=
nullptr
;
}
}
void
TermFrequencyMapFeature
::
Setup
(
TaskContext
*
context
)
{
TokenLookupFeature
::
Setup
(
context
);
context
->
GetInput
(
input_name_
,
"text"
,
""
);
}
void
TermFrequencyMapFeature
::
Init
(
TaskContext
*
context
)
{
min_freq_
=
GetIntParameter
(
"min-freq"
,
0
);
max_num_terms_
=
GetIntParameter
(
"max-num-terms"
,
0
);
file_name_
=
context
->
InputFile
(
*
context
->
GetInput
(
input_name_
));
term_map_
=
SharedStoreUtils
::
GetWithDefaultName
<
TermFrequencyMap
>
(
file_name_
,
min_freq_
,
max_num_terms_
);
TokenLookupFeature
::
Init
(
context
);
}
string
TermFrequencyMapFeature
::
GetFeatureValueName
(
FeatureValue
value
)
const
{
if
(
value
==
UnknownValue
())
return
"<UNKNOWN>"
;
if
(
value
>=
0
&&
value
<
(
NumValues
()
-
1
))
{
return
term_map_
->
GetTerm
(
value
);
}
LOG
(
ERROR
)
<<
"Invalid feature value: "
<<
value
;
return
"<INVALID>"
;
}
string
TermFrequencyMapFeature
::
WorkspaceName
()
const
{
return
SharedStoreUtils
::
CreateDefaultName
(
"term-frequency-map"
,
input_name_
,
min_freq_
,
max_num_terms_
);
}
string
Hyphen
::
GetFeatureValueName
(
FeatureValue
value
)
const
{
switch
(
value
)
{
case
NO_HYPHEN
:
      return "NO_HYPHEN";
    case HAS_HYPHEN:
      return "HAS_HYPHEN";
  }
  return "<INVALID>";
}

FeatureValue Hyphen::ComputeValue(const Token &token) const {
  const string &word = token.word();
  return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
}

string Digit::GetFeatureValueName(FeatureValue value) const {
  switch (value) {
    case NO_DIGIT:
      return "NO_DIGIT";
    case SOME_DIGIT:
      return "SOME_DIGIT";
    case ALL_DIGIT:
      return "ALL_DIGIT";
  }
  return "<INVALID>";
}

FeatureValue Digit::ComputeValue(const Token &token) const {
  const string &word = token.word();
  bool has_digit = isdigit(word[0]);
  bool all_digit = has_digit;
  for (size_t i = 1; i < word.length(); ++i) {
    bool char_is_digit = isdigit(word[i]);
    all_digit = all_digit && char_is_digit;
    has_digit = has_digit || char_is_digit;
    if (!all_digit && has_digit) return SOME_DIGIT;
  }
  if (!all_digit) return NO_DIGIT;
  return ALL_DIGIT;
}

AffixTableFeature::AffixTableFeature(AffixTable::Type type) : type_(type) {
  if (type == AffixTable::PREFIX) {
    input_name_ = "prefix-table";
  } else {
    input_name_ = "suffix-table";
  }
}

AffixTableFeature::~AffixTableFeature() {
  SharedStore::Release(affix_table_);
  affix_table_ = nullptr;
}

string AffixTableFeature::WorkspaceName() const {
  return SharedStoreUtils::CreateDefaultName("affix-table", input_name_,
                                             type_, affix_length_);
}

// Utility function to create a new affix table without changing constructors,
// to be called by the SharedStore.
static AffixTable *CreateAffixTable(const string &filename,
                                    AffixTable::Type type) {
  AffixTable *affix_table = new AffixTable(type, 1);
  tensorflow::RandomAccessFile *file;
  TF_CHECK_OK(
      tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
  ProtoRecordReader reader(file);
  affix_table->Read(&reader);
  return affix_table;
}

void AffixTableFeature::Setup(TaskContext *context) {
  context->GetInput(input_name_, "recordio", "affix-table");
  affix_length_ = GetIntParameter("length", 0);
  CHECK_GE(affix_length_, 0)
      << "Length must be specified for affix preprocessor.";
  TokenLookupFeature::Setup(context);
}

void AffixTableFeature::Init(TaskContext *context) {
  string filename = context->InputFile(*context->GetInput(input_name_));

  // Get the shared AffixTable object.
  std::function<AffixTable *()> closure =
      std::bind(CreateAffixTable, filename, type_);
  affix_table_ = SharedStore::ClosureGetOrDie(filename, &closure);
  CHECK_GE(affix_table_->max_length(), affix_length_)
      << "Affixes of length " << affix_length_ << " needed, but the affix "
      << "table only provides affixes of length <= "
      << affix_table_->max_length() << ".";
  TokenLookupFeature::Init(context);
}

FeatureValue AffixTableFeature::ComputeValue(const Token &token) const {
  const string &word = token.word();
  UnicodeText text;
  text.PointToUTF8(word.c_str(), word.size());
  if (affix_length_ > text.size()) return UnknownValue();
  UnicodeText::const_iterator start, end;
  if (type_ == AffixTable::PREFIX) {
    start = end = text.begin();
    for (int i = 0; i < affix_length_; ++i) ++end;
  } else {
    start = end = text.end();
    for (int i = 0; i < affix_length_; ++i) --start;
  }
  string affix(start.utf8_data(), end.utf8_data() - start.utf8_data());
  int affix_id = affix_table_->AffixId(affix);
  return affix_id == -1 ? UnknownValue() : affix_id;
}

string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
  if (value == UnknownValue()) return "<UNKNOWN>";
  if (value >= 0 && value < UnknownValue()) {
    return affix_table_->AffixForm(value);
  }
  LOG(ERROR) << "Invalid feature value: " << value;
  return "<INVALID>";
}

// Registry for the Sentence + token index feature functions.
REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);

// Register the features defined in the header.
REGISTER_SENTENCE_IDX_FEATURE("word", Word);
REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);

}  // namespace syntaxnet
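As a side note, the three-way scan in Digit::ComputeValue above can be read in isolation. The following is a minimal, self-contained sketch of that classification logic outside the feature framework; the enum constants mirror Digit::Category, while the function name and the driver in main are illustrative only:

#include <cctype>
#include <iostream>
#include <string>

enum Category { NO_DIGIT = 0, SOME_DIGIT = 1, ALL_DIGIT = 2 };

// Classifies a word as containing no digits, some digits, or only digits,
// matching the semantics of Digit::ComputeValue for non-empty words.
Category ClassifyDigits(const std::string &word) {
  bool has_digit = false;
  bool all_digit = !word.empty();
  for (char c : word) {
    const bool is_digit = std::isdigit(static_cast<unsigned char>(c)) != 0;
    has_digit = has_digit || is_digit;
    all_digit = all_digit && is_digit;
  }
  if (!has_digit) return NO_DIGIT;
  return all_digit ? ALL_DIGIT : SOME_DIGIT;
}

int main() {
  std::cout << ClassifyDigits("2016") << "\n";   // 2 (ALL_DIGIT)
  std::cout << ClassifyDigits("B-52") << "\n";   // 1 (SOME_DIGIT)
  std::cout << ClassifyDigits("plane") << "\n";  // 0 (NO_DIGIT)
}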
syntaxnet/syntaxnet/sentence_features.h
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Features that operate on Sentence objects. Most features are defined
// in this header so they may be re-used via composition into other more
// advanced feature classes.
#ifndef $TARGETDIR_SENTENCE_FEATURES_H_
#define $TARGETDIR_SENTENCE_FEATURES_H_
#include "syntaxnet/affix.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"
namespace syntaxnet {

// Feature function for any component that processes Sentences, whose
// focus is a token index into the sentence.
typedef FeatureFunction<Sentence, int> SentenceFeature;

// Alias for Locator type features that take (Sentence, int) signatures
// and call other (Sentence, int) features.
template <class DER>
using Locator = FeatureLocator<DER, Sentence, int>;

class TokenLookupFeature : public SentenceFeature {
 public:
  void Init(TaskContext *context) override {
    set_feature_type(new ResourceBasedFeatureType<TokenLookupFeature>(
        name(), this, {{NumValues(), "<OUTSIDE>"}}));
  }

  // Given a position in a sentence and workspaces, looks up the corresponding
  // feature value. The index is relative to the start of the sentence.
  virtual FeatureValue ComputeValue(const Token &token) const = 0;

  // Number of unique values.
  virtual int64 NumValues() const = 0;

  // Convert the numeric value of the feature to a human readable string.
  virtual string GetFeatureValueName(FeatureValue value) const = 0;

  // Name of the shared workspace.
  virtual string WorkspaceName() const = 0;

  // Runs ComputeValue for each token in the sentence.
  void Preprocess(WorkspaceSet *workspaces,
                  Sentence *sentence) const override {
    if (workspaces->Has<VectorIntWorkspace>(workspace_)) return;
    VectorIntWorkspace *workspace =
        new VectorIntWorkspace(sentence->token_size());
    for (int i = 0; i < sentence->token_size(); ++i) {
      const int value = ComputeValue(sentence->token(i));
      workspace->set_element(i, value);
    }
    workspaces->Set<VectorIntWorkspace>(workspace_, workspace);
  }

  // Requests a vector of int's to store in the workspace registry.
  void RequestWorkspaces(WorkspaceRegistry *registry) override {
    workspace_ = registry->Request<VectorIntWorkspace>(WorkspaceName());
  }

  // Returns the precomputed value, or NumValues() for features outside
  // the sentence.
  FeatureValue Compute(const WorkspaceSet &workspaces,
                       const Sentence &sentence, int focus,
                       const FeatureVector *result) const override {
    if (focus < 0 || focus >= sentence.token_size()) return NumValues();
    return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
  }

 private:
  int workspace_;
};

// Lookup feature that uses a TermFrequencyMap to store a string->int mapping.
class TermFrequencyMapFeature : public TokenLookupFeature {
 public:
  explicit TermFrequencyMapFeature(const string &input_name)
      : input_name_(input_name), min_freq_(0), max_num_terms_(0) {}
  ~TermFrequencyMapFeature() override;

  // Requests the input map as a resource.
  void Setup(TaskContext *context) override;

  // Loads the input map into memory (using SharedStore to avoid redundancy.)
  void Init(TaskContext *context) override;

  // Number of unique values.
  virtual int64 NumValues() const { return term_map_->Size() + 1; }

  // Special value for strings not in the map.
  FeatureValue UnknownValue() const { return term_map_->Size(); }

  // Uses the TermFrequencyMap to lookup the string associated with a value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Name of the shared workspace.
  string WorkspaceName() const override;

 protected:
  const TermFrequencyMap &term_map() const { return *term_map_; }

 private:
  // Shortcut pointer to shared map. Not owned.
  const TermFrequencyMap *term_map_ = nullptr;

  // Name of the input for the term map.
  string input_name_;

  // Filename of the underlying resource.
  string file_name_;

  // Minimum frequency for term map.
  int min_freq_;

  // Maximum number of terms for term map.
  int max_num_terms_;
};

class Word : public TermFrequencyMapFeature {
 public:
  Word() : TermFrequencyMapFeature("word-map") {}

  FeatureValue ComputeValue(const Token &token) const override {
    string form = token.word();
    return term_map().LookupIndex(form, UnknownValue());
  }
};

class LowercaseWord : public TermFrequencyMapFeature {
 public:
  LowercaseWord() : TermFrequencyMapFeature("lc-word-map") {}

  FeatureValue ComputeValue(const Token &token) const override {
    const string lcword = utils::Lowercase(token.word());
    return term_map().LookupIndex(lcword, UnknownValue());
  }
};

class Tag : public TermFrequencyMapFeature {
 public:
  Tag() : TermFrequencyMapFeature("tag-map") {}

  FeatureValue ComputeValue(const Token &token) const override {
    return term_map().LookupIndex(token.tag(), UnknownValue());
  }
};

class Label : public TermFrequencyMapFeature {
 public:
  Label() : TermFrequencyMapFeature("label-map") {}

  FeatureValue ComputeValue(const Token &token) const override {
    return term_map().LookupIndex(token.label(), UnknownValue());
  }
};

class LexicalCategoryFeature : public TokenLookupFeature {
 public:
  LexicalCategoryFeature(const string &name, int cardinality)
      : name_(name), cardinality_(cardinality) {}
  ~LexicalCategoryFeature() override {}

  FeatureValue NumValues() const override { return cardinality_; }

  // Returns the identifier for the workspace for this preprocessor.
  string WorkspaceName() const override {
    return tensorflow::strings::StrCat(name_, ":", cardinality_);
  }

 private:
  // Name of the category type.
  const string name_;

  // Number of values.
  const int cardinality_;
};

// Preprocessor that computes whether a word has a hyphen or not.
class Hyphen : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_HYPHEN = 0,
    HAS_HYPHEN = 1,
    CARDINALITY = 2,
  };

  // Default constructor.
  Hyphen() : LexicalCategoryFeature("hyphen", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;
};

// Preprocessor that computes whether a word contains digits or not.
class Digit : public LexicalCategoryFeature {
 public:
  // Enumeration of values.
  enum Category {
    NO_DIGIT = 0,
    SOME_DIGIT = 1,
    ALL_DIGIT = 2,
    CARDINALITY = 3,
  };

  // Default constructor.
  Digit() : LexicalCategoryFeature("digit", CARDINALITY) {}

  // Returns a string representation of the enum value.
  string GetFeatureValueName(FeatureValue value) const override;

  // Returns the category value for the token.
  FeatureValue ComputeValue(const Token &token) const override;
};

// TokenLookupPreprocessor object to compute prefixes and suffixes of words.
// The AffixTable is stored in the SharedStore. This is very similar to the
// implementation of TermFrequencyMapPreprocessor, but using an AffixTable to
// perform the lookups. There are only two specializations, for prefixes and
// suffixes.
class AffixTableFeature : public TokenLookupFeature {
 public:
  // Explicit constructor to set the type of the table. This determines the
  // requested input.
  explicit AffixTableFeature(AffixTable::Type type);
  ~AffixTableFeature() override;

  // Requests inputs for the affix table.
  void Setup(TaskContext *context) override;

  // Loads the affix table from the SharedStore.
  void Init(TaskContext *context) override;

  // The workspace name is specific to which affix length we are computing.
  string WorkspaceName() const override;

  // Returns the total number of affixes in the table, regardless of specified
  // length.
  FeatureValue NumValues() const override { return affix_table_->size() + 1; }

  // Special value for strings not in the map.
  FeatureValue UnknownValue() const { return affix_table_->size(); }

  // Looks up the affix for a given word.
  FeatureValue ComputeValue(const Token &token) const override;

  // Returns the string associated with a value.
  string GetFeatureValueName(FeatureValue value) const override;

 private:
  // Size parameter for the affix table.
  int affix_length_;

  // Name of the input for the table.
  string input_name_;

  // The type of the affix table.
  const AffixTable::Type type_;

  // Affix table used for indexing. This comes from the shared store, and is
  // not owned directly.
  const AffixTable *affix_table_ = nullptr;
};

// Specific instantiation for computing prefixes. This requires the input
// "prefix-table".
class PrefixFeature : public AffixTableFeature {
 public:
  PrefixFeature() : AffixTableFeature(AffixTable::PREFIX) {}
};

// Specific instantiation for computing suffixes. Requires the input
// "suffix-table."
class SuffixFeature : public AffixTableFeature {
 public:
  SuffixFeature() : AffixTableFeature(AffixTable::SUFFIX) {}
};

// Offset locator. Simple locator: just changes the focus by some offset.
class Offset : public Locator<Offset> {
 public:
  void UpdateArgs(const WorkspaceSet &workspaces, const Sentence &sentence,
                  int *focus) const {
    *focus += argument();
  }
};

typedef FeatureExtractor<Sentence, int> SentenceExtractor;

// Utility to register the sentence_instance::Feature functions.
#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
  REGISTER_FEATURE_FUNCTION(SentenceFeature, name, type)

}  // namespace syntaxnet

#endif  // $TARGETDIR_SENTENCE_FEATURES_H_
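For orientation, a new token-level feature slots into this header by subclassing one of the base classes above and registering under an FML name. The sketch below is hypothetical (the Capitalized class, its FML name, and its logic are not part of this commit); it relies on the declarations in this header and the registration macro defined above:

#include <cctype>

#include "syntaxnet/sentence_features.h"

namespace syntaxnet {

// Hypothetical binary feature: does the word start with an uppercase letter?
class Capitalized : public LexicalCategoryFeature {
 public:
  enum Category { NO_CAP = 0, HAS_CAP = 1, CARDINALITY = 2 };

  Capitalized() : LexicalCategoryFeature("capitalized", CARDINALITY) {}

  // Maps the enum back to a readable string, as Hyphen and Digit do.
  string GetFeatureValueName(FeatureValue value) const override {
    return value == HAS_CAP ? "HAS_CAP" : "NO_CAP";
  }

  // Checks only the first byte; real code would handle Unicode, as
  // AffixTableFeature::ComputeValue does.
  FeatureValue ComputeValue(const Token &token) const override {
    const string &word = token.word();
    return (!word.empty() && isupper(word[0])) ? HAS_CAP : NO_CAP;
  }
};

// Registration would make the feature usable in FML specs, e.g.
// "capitalized" or composed through a locator as "offset(-1).capitalized".
REGISTER_SENTENCE_IDX_FEATURE("capitalized", Capitalized);

}  // namespace syntaxnet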
syntaxnet/syntaxnet/sentence_features_test.cc
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_features.h"
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include "syntaxnet/utils.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/workspace.h"
using testing::UnorderedElementsAreArray;

namespace syntaxnet {

// A basic fixture for testing Features. Takes a string of a
// Sentence protobuf that is used as the test data in the constructor.
class SentenceFeaturesTest : public ::testing::Test {
 protected:
  explicit SentenceFeaturesTest(const string &prototxt)
      : sentence_(ParseASCII(prototxt)),
        creators_(PopulateTestInputs::Defaults(sentence_)) {}

  static Sentence ParseASCII(const string &prototxt) {
    Sentence document;
    CHECK(TextFormat::ParseFromString(prototxt, &document));
    return document;
  }

  // Prepares a new feature for extracting from the attached sentence,
  // regenerating the TaskContext and all resources. Will automatically add
  // anything in info_ field into the LexiFuse repository.
  virtual void PrepareFeature(const string &fml) {
    context_.mutable_spec()->mutable_input()->Clear();
    context_.mutable_spec()->mutable_output()->Clear();
    extractor_.reset(new SentenceExtractor());
    extractor_->Parse(fml);
    extractor_->Setup(&context_);
    creators_.Populate(&context_);
    extractor_->Init(&context_);
    extractor_->RequestWorkspaces(&registry_);
    workspaces_.Reset(registry_);
    extractor_->Preprocess(&workspaces_, &sentence_);
  }

  // Returns the string representation of the prepared feature extracted at
  // the given index.
  virtual string ExtractFeature(int index) {
    FeatureVector result;
    extractor_->ExtractFeatures(workspaces_, sentence_, index, &result);
    return result.type(0)->GetFeatureValueName(result.value(0));
  }

  // Extracts a vector of string representations from evaluating the prepared
  // set feature (returning multiple values) at the given index.
  virtual vector<string> ExtractMultiFeature(int index) {
    vector<string> values;
    FeatureVector result;
    extractor_->ExtractFeatures(workspaces_, sentence_, index, &result);
    for (int i = 0; i < result.size(); ++i) {
      values.push_back(
          result.type(i)->GetFeatureValueName(result.value(i)));
    }
    return values;
  }

  Sentence sentence_;
  WorkspaceSet workspaces_;
  PopulateTestInputs::CreatorMap creators_;
  TaskContext context_;
  WorkspaceRegistry registry_;
  std::unique_ptr<SentenceExtractor> extractor_;
};

// Test fixture for simple common features that operate on just a sentence.
class CommonSentenceFeaturesTest : public SentenceFeaturesTest {
 protected:
  CommonSentenceFeaturesTest()
      : SentenceFeaturesTest(
            "text: 'I saw a man with a telescope.' "
            "token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
            " head: 1 label: 'nsubj' break_level: NO_BREAK } "
            "token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
            " label: 'ROOT' break_level: SPACE_BREAK } "
            "token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
            " head: 3 label: 'det' break_level: SPACE_BREAK } "
            "token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
            " head: 1 label: 'dobj' break_level: SPACE_BREAK } "
            "token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
            " head: 1 label: 'prep' break_level: SPACE_BREAK } "
            "token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
            " head: 6 label: 'det' break_level: SPACE_BREAK } "
            "token { word: 'telescope' start: 19 end: 27 tag: 'NN' category: "
            "'NOUN'"
            " head: 4 label: 'pobj' break_level: SPACE_BREAK } "
            "token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
            " head: 1 label: 'p' break_level: NO_BREAK }") {}
};

TEST_F(CommonSentenceFeaturesTest, TagFeature) {
  PrepareFeature("tag");
  EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
  EXPECT_EQ("PRP", ExtractFeature(0));
  EXPECT_EQ("VBD", ExtractFeature(1));
  EXPECT_EQ("DT", ExtractFeature(2));
  EXPECT_EQ("NN", ExtractFeature(3));
  EXPECT_EQ("<OUTSIDE>", ExtractFeature(8));
}

TEST_F(CommonSentenceFeaturesTest, TagFeaturePassesArgs) {
  PrepareFeature("tag(min-freq=5)");  // don't load any tags
  EXPECT_EQ(ExtractFeature(-1), "<OUTSIDE>");
  EXPECT_EQ(ExtractFeature(0), "<UNKNOWN>");
  EXPECT_EQ(ExtractFeature(8), "<OUTSIDE>");

  // Only 2 features: <UNKNOWN> and <OUTSIDE>.
  EXPECT_EQ(2, extractor_->feature_type(0)->GetDomainSize());
}

TEST_F(CommonSentenceFeaturesTest, OffsetPlusTag) {
  PrepareFeature("offset(-1).tag(min-freq=2)");
  EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
  EXPECT_EQ("<OUTSIDE>", ExtractFeature(0));
  EXPECT_EQ("<UNKNOWN>", ExtractFeature(1));
  EXPECT_EQ("<UNKNOWN>", ExtractFeature(2));
  EXPECT_EQ("DT", ExtractFeature(3));  // DT, NN are the only freq tags
  EXPECT_EQ("NN", ExtractFeature(4));
  EXPECT_EQ("<UNKNOWN>", ExtractFeature(5));
  EXPECT_EQ("DT", ExtractFeature(6));
  EXPECT_EQ("NN", ExtractFeature(7));
  EXPECT_EQ("<UNKNOWN>", ExtractFeature(8));
  EXPECT_EQ("<OUTSIDE>", ExtractFeature(9));
}

}  // namespace syntaxnet
syntaxnet/syntaxnet/shared_store.cc
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/shared_store.h"
#include <unordered_map>
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace syntaxnet {

SharedStore::SharedObjectMap *SharedStore::shared_object_map_ =
    new SharedObjectMap;

mutex SharedStore::shared_object_map_mutex_(tensorflow::LINKER_INITIALIZED);

SharedStore::SharedObjectMap *SharedStore::shared_object_map() {
  return shared_object_map_;
}

bool SharedStore::Release(const void *object) {
  if (object == nullptr) {
    return true;
  }
  mutex_lock l(shared_object_map_mutex_);
  for (SharedObjectMap::iterator it = shared_object_map()->begin();
       it != shared_object_map()->end(); ++it) {
    if (it->second.object == object) {
      // Check the invariant that reference counts are positive. A violation
      // likely implies memory corruption.
      CHECK_GE(it->second.refcount, 1);
      it->second.refcount--;
      if (it->second.refcount == 0) {
        it->second.delete_callback();
        shared_object_map()->erase(it);
      }
      return true;
    }
  }
  return false;
}

void SharedStore::Clear() {
  mutex_lock l(shared_object_map_mutex_);
  for (SharedObjectMap::iterator it = shared_object_map()->begin();
       it != shared_object_map()->end(); ++it) {
    it->second.delete_callback();
  }
  shared_object_map()->clear();
}

string SharedStoreUtils::CreateDefaultName() { return string(); }

string SharedStoreUtils::ToString(const string &input) {
  return ToString(tensorflow::StringPiece(input));
}

string SharedStoreUtils::ToString(const char *input) {
  return ToString(tensorflow::StringPiece(input));
}

string SharedStoreUtils::ToString(tensorflow::StringPiece input) {
  return tensorflow::strings::StrCat("\"", utils::CEscape(input.ToString()),
                                     "\"");
}

string SharedStoreUtils::ToString(bool input) {
  return input ? "true" : "false";
}

string SharedStoreUtils::ToString(float input) {
  return tensorflow::strings::Printf("%af", input);
}

string SharedStoreUtils::ToString(double input) {
  return tensorflow::strings::Printf("%a", input);
}

}  // namespace syntaxnet
syntaxnet/syntaxnet/shared_store.h
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility for creating read-only objects once and sharing them across threads.
#ifndef $TARGETDIR_SHARED_STORE_H_
#define $TARGETDIR_SHARED_STORE_H_
#include <functional>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {

class SharedStore {
 public:
  // Returns an existing object with type T and name 'name' if it exists, else
  // creates one with "new T(args...)". Note: Objects will be indexed under
  // their typeid + name, so names only have to be unique within a given type.
  template <typename T, typename... Args>
  static const T *Get(const string &name,
                      Args &&... args);  // NOLINT(build/c++11)

  // Like Get(), but creates the object with "closure->Run()". If the closure
  // returns null, we store a null in the SharedStore, but note that Release()
  // cannot be used to remove it. This is because Release() finds the object
  // by associative lookup, and there may be more than one null value, so we
  // don't know which one to release. If the closure returns a duplicate value
  // (one that is pointer-equal to an object already in the SharedStore),
  // we disregard it and store null instead -- otherwise associative lookup
  // would again fail (and the reference counts would be wrong).
  template <typename T>
  static const T *ClosureGet(const string &name,
                             std::function<T *()> *closure);

  // Like ClosureGet(), but check-fails if ClosureGet() would return null.
  template <typename T>
  static const T *ClosureGetOrDie(const string &name,
                                  std::function<T *()> *closure);

  // Release an object that was acquired by Get(). When its reference count
  // hits 0, the object will be deleted. Returns true if the object was found.
  // Does nothing and returns true if the object is null.
  static bool Release(const void *object);

  // Delete all objects in the shared store.
  static void Clear();

 private:
  // A shared object.
  struct SharedObject {
    void *object;
    std::function<void()> delete_callback;
    int refcount;

    SharedObject(void *o, std::function<void()> d)
        : object(o), delete_callback(d), refcount(1) {}
  };

  // A map from keys to shared objects.
  typedef std::unordered_map<string, SharedObject> SharedObjectMap;

  // Return the shared object map.
  static SharedObjectMap *shared_object_map();

  // Return the string to use for indexing an object in the shared store.
  template <typename T>
  static string GetSharedKey(const string &name);

  // Delete an object of type T.
  template <typename T>
  static void DeleteObject(T *object);

  // Add an object to the shared object map. Return the object.
  template <typename T>
  static T *StoreObject(const string &key, T *object);

  // Increment the reference count of an object in the map. Return the object.
  template <typename T>
  static T *IncrementRefCountOfObject(SharedObjectMap::iterator it);

  // Map from keys to shared objects.
  static SharedObjectMap *shared_object_map_;
  static mutex shared_object_map_mutex_;

  TF_DISALLOW_COPY_AND_ASSIGN(SharedStore);
};

template <typename T>
string SharedStore::GetSharedKey(const string &name) {
  const std::type_index id = std::type_index(typeid(T));
  return tensorflow::strings::StrCat(id.name(), "_", name);
}

template <typename T>
void SharedStore::DeleteObject(T *object) {
  delete object;
}

template <typename T>
T *SharedStore::StoreObject(const string &key, T *object) {
  std::function<void()> delete_cb =
      std::bind(SharedStore::DeleteObject<T>, object);
  SharedObject so(object, delete_cb);
  shared_object_map()->insert(std::make_pair(key, so));
  return object;
}

template <typename T>
T *SharedStore::IncrementRefCountOfObject(SharedObjectMap::iterator it) {
  it->second.refcount++;
  return static_cast<T *>(it->second.object);
}

template <typename T, typename... Args>
const T *SharedStore::Get(const string &name,
                          Args &&... args) {  // NOLINT(build/c++11)
  mutex_lock l(shared_object_map_mutex_);
  const string key = GetSharedKey<T>(name);
  SharedObjectMap::iterator it = shared_object_map()->find(key);
  return (it == shared_object_map()->end())
             ? StoreObject<T>(key, new T(std::forward<Args>(args)...))
             : IncrementRefCountOfObject<T>(it);
}

template <typename T>
const T *SharedStore::ClosureGet(const string &name,
                                 std::function<T *()> *closure) {
  mutex_lock l(shared_object_map_mutex_);
  const string key = GetSharedKey<T>(name);
  SharedObjectMap::iterator it = shared_object_map()->find(key);
  if (it == shared_object_map()->end()) {
    // Creates a new object by calling the closure.
    T *object = (*closure)();
    if (object == nullptr) {
      LOG(ERROR) << "Closure returned a null pointer";
    } else {
      for (SharedObjectMap::iterator it = shared_object_map()->begin();
           it != shared_object_map()->end(); ++it) {
        if (it->second.object == object) {
          LOG(ERROR) << "Closure returned duplicate pointer: "
                     << "keys " << it->first << " and " << key;
          // Not a memory leak to discard pointer, since we have another copy.
          object = nullptr;
          break;
        }
      }
    }
    return StoreObject<T>(key, object);
  } else {
    return IncrementRefCountOfObject<T>(it);
  }
}

template <typename T>
const T *SharedStore::ClosureGetOrDie(const string &name,
                                      std::function<T *()> *closure) {
  const T *object = ClosureGet<T>(name, closure);
  CHECK(object != nullptr);
  return object;
}

// A collection of utility functions for working with the shared store.
class SharedStoreUtils {
 public:
  // Returns a shared object registered using a default name that is created
  // from the constructor args.
  //
  // NB: This function does not guarantee a one-to-one relationship between
  // sets of constructor args and names. See warnings on CreateDefaultName().
  // It is the caller's responsibility to ensure that the args provided will
  // result in unique names.
  template <class T, class... Args>
  static const T *GetWithDefaultName(Args &&... args) {  // NOLINT(build/c++11)
    return SharedStore::Get<T>(
        CreateDefaultName(std::forward<Args>(args)...),
        std::forward<Args>(args)...);
  }

  // Returns a string name representing the args. Implemented via a pair of
  // overloaded functions to achieve compile-time recursion.
  //
  // WARNING: It is possible for instances of different types to have the same
  // string representation. For example,
  //
  // CreateDefaultName(1) == CreateDefaultName(1ULL)
  //
  template <class First, class... Rest>
  static string CreateDefaultName(First &&first,
                                  Rest &&... rest) {  // NOLINT(build/c++11)
    return tensorflow::strings::StrCat(
        ToString<First>(std::forward<First>(first)), ",",
        CreateDefaultName(std::forward<Rest>(rest)...));
  }

  static string CreateDefaultName();

 private:
  // Returns a string representing the input. The generic implementation uses
  // StrCat(), and overloads are provided for selected types.
  template <class T>
  static string ToString(T input) {
    return tensorflow::strings::StrCat(input);
  }

  static string ToString(const string &input);
  static string ToString(const char *input);
  static string ToString(tensorflow::StringPiece input);
  static string ToString(bool input);
  static string ToString(float input);
  static string ToString(double input);

  TF_DISALLOW_COPY_AND_ASSIGN(SharedStoreUtils);
};

}  // namespace syntaxnet

#endif  // $TARGETDIR_SHARED_STORE_H_
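As a usage sketch of the API above: two components that ask for the same type and name share one instance, and the last Release() deletes it. This is illustrative only; BigTable and its constructor are made up for the sketch, while the SharedStore calls match the declarations in this header.

#include <string>

#include "syntaxnet/shared_store.h"

namespace syntaxnet {

// Stand-in for any expensive read-only resource; Get() forwards the trailing
// arguments to "new BigTable(...)" on first use.
struct BigTable {
  explicit BigTable(const std::string &path) { /* load from path */ }
};

void UseSharedTable() {
  // The first caller constructs the object; later callers with the same type
  // and name receive the same pointer with an incremented refcount.
  const BigTable *table =
      SharedStore::Get<BigTable>("my-table", std::string("/path/to/data"));
  // ... read from *table ...
  SharedStore::Release(table);  // deleted once the refcount drops to zero
}

}  // namespace syntaxnet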
syntaxnet/syntaxnet/shared_store_test.cc
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/shared_store.h"
#include <string>
#include <gmock/gmock.h>
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
using ::testing::_;

namespace syntaxnet {

struct NoArgs {
  NoArgs() { LOG(INFO) << "Calling NoArgs()"; }
};

struct OneArg {
  string name;

  explicit OneArg(const string &n) : name(n) {
    LOG(INFO) << "Calling OneArg(" << name << ")";
  }
};

struct TwoArgs {
  string name;
  int age;

  TwoArgs(const string &n, int a) : name(n), age(a) {
    LOG(INFO) << "Calling TwoArgs(" << name << ", " << age << ")";
  }
};

struct Slow {
  string lengthy;

  Slow() {
    LOG(INFO) << "Calling Slow()";
    lengthy.assign(50 << 20, 'L');  // 50MB of the letter 'L'
  }
};

struct CountCalls {
  CountCalls() {
    LOG(INFO) << "Calling CountCalls()";
    ++constructor_calls;
  }

  ~CountCalls() {
    LOG(INFO) << "Calling ~CountCalls()";
    ++destructor_calls;
  }

  static void Reset() {
    constructor_calls = 0;
    destructor_calls = 0;
  }

  static int constructor_calls;
  static int destructor_calls;
};

int CountCalls::constructor_calls = 0;
int CountCalls::destructor_calls = 0;

class PointerSet {
 public:
  PointerSet() {}

  void Add(const void *p) {
    mutex_lock l(mu_);
    pointers_.insert(p);
  }

  int size() {
    mutex_lock l(mu_);
    return pointers_.size();
  }

 private:
  mutex mu_;
  unordered_set<const void *> pointers_;
};

class SharedStoreTest : public testing::Test {
 protected:
  ~SharedStoreTest() {
    // Clear the shared store after each test, otherwise objects created
    // in one test may interfere with other tests.
    SharedStore::Clear();
  }
};

// Verify that we can call constructors with varying numbers and types of args.
TEST_F(SharedStoreTest, ConstructorArgs) {
  SharedStore::Get<NoArgs>("no args");
  SharedStore::Get<OneArg>("one arg", "Fred");
  SharedStore::Get<TwoArgs>("two args", "Pebbles", 2);
}

// Verify that an object with a given key is created only once.
TEST_F(SharedStoreTest, Shared) {
  const NoArgs *ob1 = SharedStore::Get<NoArgs>("first");
  const NoArgs *ob2 = SharedStore::Get<NoArgs>("second");
  const NoArgs *ob3 = SharedStore::Get<NoArgs>("first");
  EXPECT_EQ(ob1, ob3);
  EXPECT_NE(ob1, ob2);
  EXPECT_NE(ob2, ob3);
}

// Verify that objects with the same name but different types do not collide.
TEST_F(SharedStoreTest, DifferentTypes) {
  const NoArgs *ob1 = SharedStore::Get<NoArgs>("same");
  const OneArg *ob2 = SharedStore::Get<OneArg>("same", "foo");
  const TwoArgs *ob3 = SharedStore::Get<TwoArgs>("same", "bar", 5);
  EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob2));
  EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob3));
  EXPECT_NE(static_cast<const void *>(ob2), static_cast<const void *>(ob3));
}

// Factory method to make a OneArg.
OneArg *MakeOneArg(const string &n) { return new OneArg(n); }

TEST_F(SharedStoreTest, ClosureGet) {
  std::function<OneArg *()> closure1 = std::bind(MakeOneArg, "Al");
  std::function<OneArg *()> closure2 = std::bind(MakeOneArg, "Al");
  const OneArg *ob1 = SharedStore::ClosureGet("first", &closure1);
  const OneArg *ob2 = SharedStore::ClosureGet("first", &closure2);
  EXPECT_EQ("Al", ob1->name);
  EXPECT_EQ(ob1, ob2);
}

TEST_F(SharedStoreTest, PermanentCallback) {
  std::function<OneArg *()> closure = std::bind(MakeOneArg, "Al");
  const OneArg *ob1 = SharedStore::ClosureGet("first", &closure);
  const OneArg *ob2 = SharedStore::ClosureGet("first", &closure);
  EXPECT_EQ("Al", ob1->name);
  EXPECT_EQ(ob1, ob2);
}

// Factory method to "make" a NoArgs by simply returning an input pointer.
NoArgs *BogusMakeNoArgs(NoArgs *ob) { return ob; }

// Create a CountCalls object, pretend it failed, and return null.
CountCalls *MakeFailedCountCalls() {
  CountCalls *ob = new CountCalls;
  delete ob;
  return nullptr;
}

// Verify that ClosureGet() only calls the closure for a given key once,
// even if the closure fails.
TEST_F(SharedStoreTest, FailedClosureGet) {
  CountCalls::Reset();
  std::function<CountCalls *()> closure1(MakeFailedCountCalls);
  std::function<CountCalls *()> closure2(MakeFailedCountCalls);
  const CountCalls *ob1 = SharedStore::ClosureGet("first", &closure1);
  const CountCalls *ob2 = SharedStore::ClosureGet("first", &closure2);
  EXPECT_EQ(nullptr, ob1);
  EXPECT_EQ(nullptr, ob2);
  EXPECT_EQ(1, CountCalls::constructor_calls);
}

typedef SharedStoreTest SharedStoreDeathTest;

TEST_F(SharedStoreDeathTest, ClosureGetOrDie) {
  NoArgs *empty = nullptr;
  std::function<NoArgs *()> closure = std::bind(BogusMakeNoArgs, empty);
  EXPECT_DEATH(SharedStore::ClosureGetOrDie("first", &closure), "nullptr");
}

TEST_F(SharedStoreTest, Release) {
  const OneArg *ob1 = SharedStore::Get<OneArg>("first", "Fred");
  const OneArg *ob2 = SharedStore::Get<OneArg>("first", "Fred");
  EXPECT_EQ(ob1, ob2);
  EXPECT_TRUE(SharedStore::Release(ob1));      // now refcount = 1
  EXPECT_TRUE(SharedStore::Release(ob1));      // now object is deleted
  EXPECT_FALSE(SharedStore::Release(ob1));     // now object is not found
  EXPECT_TRUE(SharedStore::Release(nullptr));  // release(nullptr) returns true
}

TEST_F(SharedStoreTest, Clear) {
  CountCalls::Reset();
  SharedStore::Get<CountCalls>("first");
  SharedStore::Get<CountCalls>("second");
  SharedStore::Get<CountCalls>("first");

  // Test that the constructor and destructor are each called exactly once
  // for each key in the shared store.
  SharedStore::Clear();
  EXPECT_EQ(2, CountCalls::constructor_calls);
  EXPECT_EQ(2, CountCalls::destructor_calls);
}

void GetSharedObject(PointerSet *ps) {
  // Gets a shared object whose constructor takes a long time.
  const Slow *ob = SharedStore::Get<Slow>("first");

  // Collects the pointer we got. Later, we'll check whether SharedStore
  // mistakenly called the constructor more than once.
  ps->Add(static_cast<const void *>(ob));
}

// If multiple parallel threads all access an object with the same key,
// only one object is created.
TEST_F(SharedStoreTest, ThreadSafety) {
  const int kNumThreads = 20;
  tensorflow::thread::ThreadPool *pool = new tensorflow::thread::ThreadPool(
      tensorflow::Env::Default(), "ThreadSafetyPool", kNumThreads);
  PointerSet ps;
  for (int i = 0; i < kNumThreads; ++i) {
    std::function<void()> closure = std::bind(GetSharedObject, &ps);
    pool->Schedule(closure);
  }

  // Waits for closures to finish, then delete the pool.
  delete pool;

  // Expects only one object to have been created across all threads.
  EXPECT_EQ(1, ps.size());
}

}  // namespace syntaxnet
syntaxnet/syntaxnet/sparse.proto
// Protocol for passing around sparse sets of features.
syntax = "proto2";

package syntaxnet;

// A sparse set of features.
//
// If using SparseStringToIdTransformer, description is required and id should
// be omitted; otherwise, id is required and description optional.
//
// id, weight, and description fields are all aligned if present (ie, any of
// these that are non-empty should have the same # items). If weight is
// omitted, 1.0 is used.
message SparseFeatures {
  repeated uint64 id = 1;
  repeated float weight = 2;
  repeated string description = 3;
};
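A minimal sketch of filling this message from C++ with aligned fields, per the comment above. The include path assumes the generated header lands at syntaxnet/sparse.pb.h under the tf_proto_library rules in this package; the field values are illustrative.

#include "syntaxnet/sparse.pb.h"

// Builds a SparseFeatures message with one (id, weight, description) triple.
// The three repeated fields stay aligned: entry i of each refers to the same
// feature. Weights left empty would default to 1.0 per the proto comment.
syntaxnet::SparseFeatures MakeExample() {
  syntaxnet::SparseFeatures features;
  features.add_id(42);
  features.add_weight(0.5f);
  features.add_description("prefix=un");
  return features;
}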
syntaxnet/syntaxnet/structured_graph_builder.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build structured parser models."""
import tensorflow as tf

from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops

from syntaxnet import graph_builder
from syntaxnet.ops import gen_parser_ops

tf.NoGradient('BeamParseReader')
tf.NoGradient('BeamParser')
tf.NoGradient('BeamParserOutput')


def AddCrossEntropy(batch_size, n):
  """Adds a cross entropy cost function."""
  cross_entropies = []
  def _Pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

  for beam_id in range(batch_size):
    beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1])
    def _ComputeCrossEntropy():
      """Adds ops to compute cross entropy of the gold path in a beam."""
      # Requires a cast so that UnsortedSegmentSum, in the gradient,
      # is happy with the type of its input 'segment_ids', which
      # must be int32.
      idx = tf.cast(
          tf.reshape(tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]),
          tf.int32)
      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
      num = tf.shape(idx)
      return tf.nn.softmax_cross_entropy_with_logits(
          beam_scores, tf.expand_dims(
              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0))
    # The conditional here is needed to deal with the last few batches of the
    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
    cross_entropies.append(cf.cond(
        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}


class StructuredGraphBuilder(graph_builder.GreedyParser):
  """Extends the standard GreedyParser with a CRF objective using a beam.

  The constructor takes two additional keyword arguments.
  beam_size: the maximum size the beam can grow to.
  max_steps: the maximum number of steps in any particular beam.

  The model supports batch training with the batch_size argument to the
  AddTraining method.
  """

  def __init__(self, *args, **kwargs):
    self._beam_size = kwargs.pop('beam_size', 10)
    self._max_steps = kwargs.pop('max_steps', 25)
    super(StructuredGraphBuilder, self).__init__(*args, **kwargs)

  def _AddBeamReader(self,
                     task_context,
                     batch_size,
                     corpus_name,
                     until_all_final=False,
                     always_start_new_sentences=False):
    """Adds an op capable of reading sentences and parsing them with a beam."""
    features, state, epochs = gen_parser_ops.beam_parse_reader(
        task_context=task_context,
        feature_size=self._feature_size,
        beam_size=self._beam_size,
        batch_size=batch_size,
        corpus_name=corpus_name,
        allow_feature_weights=self._allow_feature_weights,
        arg_prefix=self._arg_prefix,
        continue_until_all_final=until_all_final,
        always_start_new_sentences=always_start_new_sentences)
    return {'state': state, 'features': features, 'epochs': epochs}

  def _BuildSequence(self,
                     batch_size,
                     max_steps,
                     features,
                     state,
                     use_average=False):
    """Adds a sequence of beam parsing steps."""
    def Advance(state, step, scores_array, alive, alive_steps, *features):
      scores = self._BuildNetwork(features,
                                  return_average=use_average)['logits']
      scores_array = scores_array.write(step, scores)
      features, state, alive = (
          gen_parser_ops.beam_parser(state, scores, self._feature_size))
      return [state, step + 1, scores_array, alive,
              alive_steps + tf.cast(alive, tf.int32)] + list(features)

    # args: (state, step, scores_array, alive, alive_steps, *features)
    def KeepGoing(*args):
      return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))

    step = tf.constant(0, tf.int32, [])
    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=0,
                                                dynamic_size=True)
    alive = tf.constant(True, tf.bool, [batch_size])
    alive_steps = tf.constant(0, tf.int32, [batch_size])
    t = tf.while_loop(
        KeepGoing,
        Advance,
        [state, step, scores_array, alive, alive_steps] + list(features),
        parallel_iterations=100)

    # Link to the final nodes/values of ops that have passed through While:
    return {'state': t[0],
            'concat_scores': t[2].concat(),
            'alive': t[3],
            'alive_steps': t[4]}

  def AddTraining(self,
                  task_context,
                  batch_size,
                  learning_rate=0.1,
                  decay_steps=4000,
                  momentum=None,
                  corpus_name='documents'):
    with tf.name_scope('training'):
      n = self.training
      n['accumulated_alive_steps'] = self._AddVariable(
          [batch_size], tf.int32, 'accumulated_alive_steps',
          tf.zeros_initializer)
      n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
      # This adds a required 'step' node too:
      learning_rate = tf.constant(learning_rate, dtype=tf.float32)
      n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)
      # Call BuildNetwork *only* to set up the params outside of the main loop.
      self._BuildNetwork(list(n['features']))

      n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
                                   n['state']))

      flat_concat_scores = tf.reshape(n['concat_scores'], [-1])
      (indices_and_paths, beams_and_slots, n['gold_slot'],
       n['beam_path_scores']) = gen_parser_ops.beam_parser_output(n['state'])
      n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
      n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
      n['all_path_scores'] = tf.sparse_segment_sum(
          flat_concat_scores, n['indices'], n['path_ids'])
      n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])
      n.update(AddCrossEntropy(batch_size, n))

      if self._only_train:
        trainable_params = {k: v for k, v in self.params.iteritems()
                            if k in self._only_train}
      else:
        trainable_params = self.params
      for p in trainable_params:
        tf.logging.info('trainable_param: %s', p)

      regularized_params = [
          tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
          if k.startswith('weights') or k.startswith('bias')]
      l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0

      n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')

      n['gradients'] = tf.gradients(n['cost'], trainable_params.values())

      with tf.control_dependencies([n['alive_steps']]):
        update_accumulators = tf.group(
            tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))

      def ResetAccumulators():
        return tf.assign(n['accumulated_alive_steps'],
                         tf.zeros([batch_size], tf.int32))
      n['reset_accumulators_func'] = ResetAccumulators

      optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
                                             momentum,
                                             use_locking=self._use_locking)
      train_op = optimizer.minimize(n['cost'],
                                    var_list=trainable_params.values())
      for param in trainable_params.values():
        slot = optimizer.get_slot(param, 'momentum')
        self.inits[slot.name] = state_ops.init_variable(slot,
                                                        tf.zeros_initializer)
        self.variables[slot.name] = slot

      def NumericalChecks():
        return tf.group(*[
            tf.check_numerics(param, message='Parameter is not finite.')
            for param in trainable_params.values()
            if param.dtype.base_dtype in [tf.float32, tf.float64]])
      check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every),
                                  0),
                         NumericalChecks, tf.no_op)
      avg_update_op = tf.group(*self._averaging.values())
      train_ops = [train_op]
      if self._check_parameters:
        train_ops.append(check_op)
      if self._use_averaging:
        train_ops.append(avg_update_op)
      with tf.control_dependencies([update_accumulators]):
        n['train_op'] = tf.group(*train_ops, name='train_op')
      n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
    return n

  def AddEvaluation(self,
                    task_context,
                    batch_size,
                    evaluation_max_steps=300,
                    corpus_name=None):
    with tf.name_scope('evaluation'):
      n = self.evaluation
      n.update(self._AddBeamReader(task_context,
                                   batch_size,
                                   corpus_name,
                                   until_all_final=True,
                                   always_start_new_sentences=True))
      self._BuildNetwork(list(n['features']),
                         return_average=self._use_averaging)
      n.update(self._BuildSequence(batch_size, evaluation_max_steps,
                                   n['features'], n['state'],
                                   use_average=self._use_averaging))
      n['eval_metrics'], n['documents'] = (
          gen_parser_ops.beam_eval_output(n['state']))
    return n
syntaxnet/syntaxnet/syntaxnet.bzl
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
load("@tf//google/protobuf:protobuf.bzl", "cc_proto_library")
load("@tf//google/protobuf:protobuf.bzl", "py_proto_library")

def if_cuda(a, b=[]):
  return select({
      "@tf//third_party/gpus/cuda:cuda_crosstool_condition": a,
      "//conditions:default": b,
  })

def tf_copts():
  return (["-fno-exceptions",
           "-DEIGEN_AVOID_STL_ARRAY",] +
          if_cuda(["-DGOOGLE_CUDA=1"]) +
          select({"@tf//tensorflow:darwin": [],
                  "//conditions:default": ["-pthread"]}))

def tf_proto_library(name, srcs=[], has_services=False,
                     deps=[], visibility=None, testonly=0,
                     cc_api_version=2, go_api_version=2,
                     java_api_version=2, py_api_version=2):
  native.filegroup(name=name + "_proto_srcs",
                   srcs=srcs,
                   testonly=testonly,)

  cc_proto_library(name=name,
                   srcs=srcs,
                   deps=deps,
                   cc_libs=["@tf//google/protobuf:protobuf"],
                   protoc="@tf//google/protobuf:protoc",
                   default_runtime="@tf//google/protobuf:protobuf",
                   testonly=testonly,
                   visibility=visibility,)

def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
  py_proto_library(name=name,
                   srcs=srcs,
                   srcs_version="PY2AND3",
                   deps=deps,
                   default_runtime="@tf//google/protobuf:protobuf_python",
                   protoc="@tf//google/protobuf:protoc",
                   visibility=visibility,
                   testonly=testonly,)

# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names):
  # Make library out of each op so it can also be used to generate wrappers
  # for various languages.
  for n in op_lib_names:
    native.cc_library(name=n + "_op_lib",
                      copts=tf_copts(),
                      srcs=["ops/" + n + ".cc"],
                      deps=(["@tf//tensorflow/core:framework"]),
                      visibility=["//visibility:public"],
                      alwayslink=1,
                      linkstatic=1,)

# Invoke this rule in .../tensorflow/python to build the wrapper library.
def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
                         require_shape_functions=False):
  # Construct a cc_binary containing the specified ops.
  tool_name = "gen_" + name + "_py_wrappers_cc"
  if not deps:
    deps = ["//tensorflow/core:" + name + "_op_lib"]
  native.cc_binary(
      name=tool_name,
      linkopts=["-lm"],
      copts=tf_copts(),
      linkstatic=1,  # Faster to link this one-time-use binary dynamically
      deps=(["@tf//tensorflow/core:framework",
             "@tf//tensorflow/python:python_op_gen_main"] + deps),
  )

  # Invoke the previous cc_binary to generate a python file.
  if not out:
    out = "ops/gen_" + name + ".py"
  native.genrule(
      name=name + "_pygenrule",
      outs=[out],
      tools=[tool_name],
      cmd=("$(location " + tool_name + ") " + ",".join(hidden) + " " +
           ("1" if require_shape_functions else "0") + " > $@"))

  # Make a py_library out of the generated python file.
  native.py_library(name=name,
                    srcs=[out],
                    srcs_version="PY2AND3",
                    visibility=visibility,
                    deps=[
                        "@tf//tensorflow/python:framework_for_generated_wrappers",
                    ],)
syntaxnet/syntaxnet/tagger_transitions.cc
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tagger transition system.
//
// This transition system has one type of action:
// - The SHIFT action pushes the next input token to the stack and
// advances to the next input token, assigning a part-of-speech tag to the
// token that was shifted.
//
// The transition system operates with parser actions encoded as integers:
// - A SHIFT action is encoded as a number starting from 0.
#include <string>
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace
syntaxnet
{
class
TaggerTransitionState
:
public
ParserTransitionState
{
public:
explicit
TaggerTransitionState
(
const
TermFrequencyMap
*
tag_map
,
const
TagToCategoryMap
*
tag_to_category
)
:
tag_map_
(
tag_map
),
tag_to_category_
(
tag_to_category
)
{}
explicit
TaggerTransitionState
(
const
TaggerTransitionState
*
state
)
:
TaggerTransitionState
(
state
->
tag_map_
,
state
->
tag_to_category_
)
{
tag_
=
state
->
tag_
;
gold_tag_
=
state
->
gold_tag_
;
}
// Clones the transition state by returning a new object.
ParserTransitionState
*
Clone
()
const
override
{
return
new
TaggerTransitionState
(
this
);
}
// Reads gold tags for each token.
void
Init
(
ParserState
*
state
)
{
tag_
.
resize
(
state
->
sentence
().
token_size
(),
-
1
);
gold_tag_
.
resize
(
state
->
sentence
().
token_size
(),
-
1
);
for
(
int
pos
=
0
;
pos
<
state
->
sentence
().
token_size
();
++
pos
)
{
int
tag
=
tag_map_
->
LookupIndex
(
state
->
GetToken
(
pos
).
tag
(),
-
1
);
gold_tag_
[
pos
]
=
tag
;
}
}
// Returns the tag assigned to a given token.
int
Tag
(
int
index
)
const
{
DCHECK_GE
(
index
,
0
);
DCHECK_LT
(
index
,
tag_
.
size
());
return
index
==
-
1
?
-
1
:
tag_
[
index
];
}
// Sets this tag on the token at index.
void
SetTag
(
int
index
,
int
tag
)
{
DCHECK_GE
(
index
,
0
);
DCHECK_LT
(
index
,
tag_
.
size
());
tag_
[
index
]
=
tag
;
}
// Returns the gold tag for a given token.
int
GoldTag
(
int
index
)
const
{
DCHECK_GE
(
index
,
-
1
);
DCHECK_LT
(
index
,
gold_tag_
.
size
());
return
index
==
-
1
?
-
1
:
gold_tag_
[
index
];
}
// Returns the string representation of a POS tag, or an empty string
// if the tag is invalid.
string
TagAsString
(
int
tag
)
const
{
if
(
tag
>=
0
&&
tag
<
tag_map_
->
Size
())
{
return
tag_map_
->
GetTerm
(
tag
);
}
return
""
;
}
// Adds transition state specific annotations to the document.
void
AddParseToDocument
(
const
ParserState
&
state
,
bool
rewrite_root_labels
,
Sentence
*
sentence
)
const
override
{
for
(
size_t
i
=
0
;
i
<
tag_
.
size
();
++
i
)
{
Token
*
token
=
sentence
->
mutable_token
(
i
);
token
->
set_tag
(
TagAsString
(
Tag
(
i
)));
token
->
set_category
(
tag_to_category_
->
GetCategory
(
token
->
tag
()));
}
}
// Whether a parsed token should be considered correct for evaluation.
bool
IsTokenCorrect
(
const
ParserState
&
state
,
int
index
)
const
override
{
return
GoldTag
(
index
)
==
Tag
(
index
);
}
// Returns a human readable string representation of this state.
string
ToString
(
const
ParserState
&
state
)
const
override
{
string
str
;
for
(
int
i
=
state
.
StackSize
();
i
>
0
;
--
i
)
{
const
string
&
word
=
state
.
GetToken
(
state
.
Stack
(
i
-
1
)).
word
();
if
(
i
!=
state
.
StackSize
()
-
1
)
str
.
append
(
" "
);
tensorflow
::
strings
::
StrAppend
(
&
str
,
word
,
"["
,
TagAsString
(
Tag
(
state
.
StackSize
()
-
i
)),
"]"
);
}
for
(
int
i
=
state
.
Next
();
i
<
state
.
NumTokens
();
++
i
)
{
tensorflow
::
strings
::
StrAppend
(
&
str
,
" "
,
state
.
GetToken
(
i
).
word
());
}
return
str
;
}
private:
// Currently assigned POS tags for each token in this sentence.
vector
<
int
>
tag_
;
// Gold POS tags from the input document.
vector
<
int
>
gold_tag_
;
// Tag map used for conversions between integer and string representations
// part of speech tags. Not owned.
const
TermFrequencyMap
*
tag_map_
=
nullptr
;
// Tag to category map. Not owned.
const
TagToCategoryMap
*
tag_to_category_
=
nullptr
;
TF_DISALLOW_COPY_AND_ASSIGN
(
TaggerTransitionState
);
};
class TaggerTransitionSystem : public ParserTransitionSystem {
 public:
  ~TaggerTransitionSystem() override { SharedStore::Release(tag_map_); }

  // Determines tag map location.
  void Setup(TaskContext *context) override {
    input_tag_map_ = context->GetInput("tag-map", "text", "");
    input_tag_to_category_ = context->GetInput("tag-to-category", "text", "");
  }

  // Reads tag map and tag to category map.
  void Init(TaskContext *context) {
    const string tag_map_path = TaskContext::InputFile(*input_tag_map_);
    tag_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
        tag_map_path, 0, 0);
    const string tag_to_category_path =
        TaskContext::InputFile(*input_tag_to_category_);
    tag_to_category_ = SharedStoreUtils::GetWithDefaultName<TagToCategoryMap>(
        tag_to_category_path);
  }

  // The SHIFT action uses the same value as the corresponding action type.
  static ParserAction ShiftAction(int tag) { return tag; }

  // Returns the number of action types.
  int NumActionTypes() const override { return 1; }

  // Returns the number of possible actions.
  int NumActions(int num_labels) const override { return tag_map_->Size(); }

  // The default action for a given state is assigning the most frequent tag.
  ParserAction GetDefaultAction(const ParserState &state) const override {
    return ShiftAction(0);
  }

  // Returns the next gold action for a given state according to the
  // underlying annotated sentence.
  ParserAction GetNextGoldAction(const ParserState &state) const override {
    if (!state.EndOfInput()) {
      return ShiftAction(TransitionState(state).GoldTag(state.Next()));
    }
    return ShiftAction(0);
  }

  // Checks if the action is allowed in a given parser state.
  bool IsAllowedAction(ParserAction action,
                       const ParserState &state) const override {
    return !state.EndOfInput();
  }

  // Makes a shift by pushing the next input token on the stack and moving to
  // the next position.
  void PerformActionWithoutHistory(ParserAction action,
                                   ParserState *state) const override {
    DCHECK(!state->EndOfInput());
    if (!state->EndOfInput()) {
      MutableTransitionState(state)->SetTag(state->Next(), action);
      state->Push(state->Next());
      state->Advance();
    }
  }
  // We are in a final state when we have reached the end of the input.
  bool IsFinalState(const ParserState &state) const override {
    return state.EndOfInput();
  }
  // Returns a string representation of a parser action.
  string ActionAsString(ParserAction action,
                        const ParserState &state) const override {
    return tensorflow::strings::StrCat("SHIFT(", tag_map_->GetTerm(action),
                                       ")");
  }

  // No state is deterministic in this transition system.
  bool IsDeterministicState(const ParserState &state) const override {
    return false;
  }

  // Returns a new transition state to be used to enhance the parser state.
  ParserTransitionState *NewTransitionState(bool training_mode) const override {
    return new TaggerTransitionState(tag_map_, tag_to_category_);
  }

  // Downcasts the const ParserTransitionState in ParserState to a const
  // TaggerTransitionState.
  static const TaggerTransitionState &TransitionState(
      const ParserState &state) {
    return *static_cast<const TaggerTransitionState *>(
        state.transition_state());
  }

  // Downcasts the ParserTransitionState in ParserState to a
  // TaggerTransitionState.
  static TaggerTransitionState *MutableTransitionState(ParserState *state) {
    return static_cast<TaggerTransitionState *>(
        state->mutable_transition_state());
  }
  // Input for the tag map. Not owned.
  TaskInput *input_tag_map_ = nullptr;

  // Tag map used for conversions between integer and string representations
  // of part-of-speech tags. Owned through SharedStore.
  const TermFrequencyMap *tag_map_ = nullptr;

  // Input for the tag to category map. Not owned.
  TaskInput *input_tag_to_category_ = nullptr;

  // Tag to category map. Owned through SharedStore.
  const TagToCategoryMap *tag_to_category_ = nullptr;
};
REGISTER_TRANSITION_SYSTEM("tagger", TaggerTransitionSystem);

}  // namespace syntaxnet
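For orientation: the REGISTER_TRANSITION_SYSTEM call makes the class constructible by its string name through ParserTransitionSystem::Create, which is exactly how the test file below obtains it. A minimal sketch under that assumption (the wrapper name MakeTaggerSystem is ours, not part of this commit):

#include <memory>

#include "syntaxnet/parser_transitions.h"

// Constructs the registered "tagger" transition system by name. Setup() and
// Init() must still be called with a populated TaskContext before use.
std::unique_ptr<syntaxnet::ParserTransitionSystem> MakeTaggerSystem() {
  return std::unique_ptr<syntaxnet::ParserTransitionSystem>(
      syntaxnet::ParserTransitionSystem::Create("tagger"));
}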
syntaxnet/syntaxnet/tagger_transitions_test.cc
0 → 100644
View file @
32ab5a58
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include "syntaxnet/utils.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {

class TaggerTransitionTest : public ::testing::Test {
 public:
  TaggerTransitionTest()
      : transition_system_(ParserTransitionSystem::Create("tagger")) {}
 protected:
  // Creates a label map and a tag map for testing based on the given
  // document and initializes the transition system appropriately.
  void SetUpForDocument(const Sentence &document) {
    input_label_map_ = context_.GetInput("label-map", "text", "");
    input_label_map_ = context_.GetInput("tag-map", "text", "");
    transition_system_->Setup(&context_);
    PopulateTestInputs::Defaults(document).Populate(&context_);
    label_map_.Load(TaskContext::InputFile(*input_label_map_),
                    0 /* minimum frequency */,
                    -1 /* maximum number of terms */);
    transition_system_->Init(&context_);
  }
  // Creates a cloned state from a sentence in order to test that cloning
  // works correctly for the new parser states.
  ParserState *NewClonedState(Sentence *sentence) {
    ParserState state(sentence,
                      transition_system_->NewTransitionState(
                          true /* training mode */),
                      &label_map_);
    return state.Clone();
  }
  // Performs gold transitions and checks that the labels and heads recorded
  // in the parser state match gold heads and labels.
  void GoldParse(Sentence *sentence) {
    ParserState *state = NewClonedState(sentence);
    LOG(INFO) << "Initial parser state: " << state->ToString();
    while (!transition_system_->IsFinalState(*state)) {
      ParserAction action = transition_system_->GetNextGoldAction(*state);
      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
      LOG(INFO) << "Performing action: "
                << transition_system_->ActionAsString(action, *state);
      transition_system_->PerformActionWithoutHistory(action, state);
      LOG(INFO) << "Parser state: " << state->ToString();
    }
    delete state;
  }
  // Always takes the default action, and verifies that this leads to
  // a final state through a sequence of allowed actions.
  void DefaultParse(Sentence *sentence) {
    ParserState *state = NewClonedState(sentence);
    LOG(INFO) << "Initial parser state: " << state->ToString();
    while (!transition_system_->IsFinalState(*state)) {
      ParserAction action = transition_system_->GetDefaultAction(*state);
      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
      LOG(INFO) << "Performing action: "
                << transition_system_->ActionAsString(action, *state);
      transition_system_->PerformActionWithoutHistory(action, state);
      LOG(INFO) << "Parser state: " << state->ToString();
    }
    delete state;
  }
  TaskContext context_;
  TaskInput *input_label_map_ = nullptr;
  TermFrequencyMap label_map_;
  std::unique_ptr<ParserTransitionSystem> transition_system_;
};
TEST_F(TaggerTransitionTest, SingleSentenceDocumentTest) {
  string document_text;
  Sentence document;
  TF_CHECK_OK(ReadFileToString(tensorflow::Env::Default(),
                               "syntaxnet/testdata/document",
                               &document_text));
  LOG(INFO) << "see doc\n:" << document_text;
  CHECK(TextFormat::ParseFromString(document_text, &document));
  SetUpForDocument(document);
  GoldParse(&document);
  DefaultParse(&document);
}
}  // namespace syntaxnet
syntaxnet/syntaxnet/task_context.cc
0 → 100644
View file @
32ab5a58
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/task_context.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {
namespace {

const char *const kShardPrintFormat = "%05d";

}  // namespace
TaskInput *TaskContext::GetInput(const string &name) {
  // Return existing input if it exists.
  for (int i = 0; i < spec_.input_size(); ++i) {
    if (spec_.input(i).name() == name) return spec_.mutable_input(i);
  }

  // Create new input.
  TaskInput *input = spec_.add_input();
  input->set_name(name);
  return input;
}
TaskInput *TaskContext::GetInput(const string &name, const string &file_format,
                                 const string &record_format) {
  TaskInput *input = GetInput(name);
  if (!file_format.empty()) {
    bool found = false;
    for (int i = 0; i < input->file_format_size(); ++i) {
      if (input->file_format(i) == file_format) found = true;
    }
    if (!found) input->add_file_format(file_format);
  }
  if (!record_format.empty()) {
    bool found = false;
    for (int i = 0; i < input->record_format_size(); ++i) {
      if (input->record_format(i) == record_format) found = true;
    }
    if (!found) input->add_record_format(record_format);
  }
  return input;
}
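A brief usage sketch of the accumulating behavior above: repeated calls with the same name return the same TaskInput and add each requested format at most once. The input and format names here are invented for illustration:

#include "syntaxnet/task_context.h"

// Illustrative only: "training-corpus", "recordio", and "sentence" are
// made-up names, not inputs defined by this commit.
void GetInputExample(syntaxnet::TaskContext *context) {
  syntaxnet::TaskInput *corpus =
      context->GetInput("training-corpus", "recordio", "sentence");
  corpus = context->GetInput("training-corpus", "recordio", "sentence");
  // corpus->file_format_size() == 1 && corpus->record_format_size() == 1
}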
void TaskContext::SetParameter(const string &name, const string &value) {
  // If the parameter already exists, update the value.
  for (int i = 0; i < spec_.parameter_size(); ++i) {
    if (spec_.parameter(i).name() == name) {
      spec_.mutable_parameter(i)->set_value(value);
      return;
    }
  }

  // Add new parameter.
  TaskSpec::Parameter *param = spec_.add_parameter();
  param->set_name(name);
  param->set_value(value);
}
string TaskContext::GetParameter(const string &name) const {
  // First try to find parameter in task specification.
  for (int i = 0; i < spec_.parameter_size(); ++i) {
    if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
  }

  // Parameter not found, return empty string.
  return "";
}

int TaskContext::GetIntParameter(const string &name) const {
  string value = GetParameter(name);
  return utils::ParseUsing<int>(value, 0, utils::ParseInt32);
}

int64 TaskContext::GetInt64Parameter(const string &name) const {
  string value = GetParameter(name);
  return utils::ParseUsing<int64>(value, 0ll, utils::ParseInt64);
}

bool TaskContext::GetBoolParameter(const string &name) const {
  string value = GetParameter(name);
  return value == "true";
}

double TaskContext::GetFloatParameter(const string &name) const {
  string value = GetParameter(name);
  return utils::ParseUsing<double>(value, .0, utils::ParseDouble);
}
string TaskContext::Get(const string &name, const char *defval) const {
  // First try to find parameter in task specification.
  for (int i = 0; i < spec_.parameter_size(); ++i) {
    if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
  }

  // Parameter not found, return default value.
  return defval;
}

string TaskContext::Get(const string &name, const string &defval) const {
  return Get(name, defval.c_str());
}

int TaskContext::Get(const string &name, int defval) const {
  string value = Get(name, "");
  return utils::ParseUsing<int>(value, defval, utils::ParseInt32);
}

int64 TaskContext::Get(const string &name, int64 defval) const {
  string value = Get(name, "");
  return utils::ParseUsing<int64>(value, defval, utils::ParseInt64);
}

double TaskContext::Get(const string &name, double defval) const {
  string value = Get(name, "");
  return utils::ParseUsing<double>(value, defval, utils::ParseDouble);
}

bool TaskContext::Get(const string &name, bool defval) const {
  string value = Get(name, "");
  return value.empty() ? defval : value == "true";
}
string TaskContext::InputFile(const TaskInput &input) {
  CHECK_EQ(input.part_size(), 1) << input.name();
  return input.part(0).file_pattern();
}
bool TaskContext::Supports(const TaskInput &input, const string &file_format,
                           const string &record_format) {
  // Check file format.
  if (input.file_format_size() > 0) {
    bool found = false;
    for (int i = 0; i < input.file_format_size(); ++i) {
      if (input.file_format(i) == file_format) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }

  // Check record format.
  if (input.record_format_size() > 0) {
    bool found = false;
    for (int i = 0; i < input.record_format_size(); ++i) {
      if (input.record_format(i) == record_format) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }

  return true;
}

}  // namespace syntaxnet
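A short usage sketch of the parameter accessors above. The typed Get() overloads fall back to the supplied default when a parameter is absent or unparsable, while the GetXxxParameter() variants default to zero/false/empty. The parameter names here are hypothetical:

#include <string>

#include "syntaxnet/task_context.h"

// Hypothetical parameter names; only the accessor behavior is taken from the
// code above.
void ParameterExample() {
  syntaxnet::TaskContext context;
  context.SetParameter("beam_size", "8");
  int beam = context.Get("beam_size", 16);            // parses "8" -> 8
  bool avg = context.Get("use_averaging", true);      // missing -> default true
  std::string s = context.GetParameter("other-name"); // missing -> ""
}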