Commit 32ab5a58 authored by calberti, committed by Martin Wicke

Adding SyntaxNet to tensorflow/models (#63)

parent 148a15fb
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_PROTO_IO_H_
#define $TARGETDIR_PROTO_IO_H_
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/document_format.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/registry.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {
// A convenience wrapper to read protos with a RecordReader.
class ProtoRecordReader {
public:
explicit ProtoRecordReader(tensorflow::RandomAccessFile *file)
: file_(file), reader_(new tensorflow::io::RecordReader(file_)) {}
explicit ProtoRecordReader(const string &filename) {
TF_CHECK_OK(
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file_));
reader_.reset(new tensorflow::io::RecordReader(file_));
}
~ProtoRecordReader() {
reader_.reset();
delete file_;
}
template <typename T>
tensorflow::Status Read(T *proto) {
string buffer;
tensorflow::Status status = reader_->ReadRecord(&offset_, &buffer);
if (status.ok()) {
CHECK(proto->ParseFromString(buffer));
return tensorflow::Status::OK();
} else {
return status;
}
}
private:
tensorflow::RandomAccessFile *file_ = nullptr;
uint64 offset_ = 0;
std::unique_ptr<tensorflow::io::RecordReader> reader_;
};
// A convenience wrapper to write protos with a RecordWriter.
class ProtoRecordWriter {
public:
explicit ProtoRecordWriter(const string &filename) {
TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file_));
writer_.reset(new tensorflow::io::RecordWriter(file_));
}
~ProtoRecordWriter() {
writer_.reset();
delete file_;
}
template <typename T>
void Write(const T &proto) {
TF_CHECK_OK(writer_->WriteRecord(proto.SerializeAsString()));
}
private:
tensorflow::WritableFile *file_ = nullptr;
std::unique_ptr<tensorflow::io::RecordWriter> writer_;
};
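// A minimal usage sketch (the file paths are illustrative only): round-tripping
// Sentence protos through a recordio file with the wrappers above.
//
//   ProtoRecordReader reader("/path/to/input.recordio");
//   ProtoRecordWriter writer("/path/to/output.recordio");
//   Sentence sentence;
//   while (reader.Read(&sentence).ok()) writer.Write(sentence);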
// A file implementation to read from stdin.
class StdIn : public tensorflow::RandomAccessFile {
public:
StdIn() {}
~StdIn() override {}
// Reads up to n bytes from standard input. Returns `OUT_OF_RANGE` if fewer
// than n bytes were stored in `*result` because of EOF.
tensorflow::Status Read(uint64 offset, size_t n,
tensorflow::StringPiece *result,
char *scratch) const override {
CHECK_EQ(expected_offset_, offset);
if (!eof_) {
string line;
eof_ = !std::getline(std::cin, line);
buffer_.append(line);
buffer_.append("\n");
}
CopyFromBuffer(std::min(buffer_.size(), n), result, scratch);
if (eof_) {
return tensorflow::errors::OutOfRange("End of file reached");
} else {
return tensorflow::Status::OK();
}
}
private:
void CopyFromBuffer(size_t n, tensorflow::StringPiece *result,
char *scratch) const {
memcpy(scratch, buffer_.data(), n);
buffer_ = buffer_.substr(n);
result->set(scratch, n);
expected_offset_ += n;
}
mutable bool eof_ = false;
mutable int64 expected_offset_ = 0;
mutable string buffer_;
TF_DISALLOW_COPY_AND_ASSIGN(StdIn);
};
// Reads sentence protos from a text file.
class TextReader {
public:
explicit TextReader(const TaskInput &input) {
CHECK_EQ(input.record_format_size(), 1)
<< "TextReader only supports inputs with one record format: "
<< input.DebugString();
CHECK_EQ(input.part_size(), 1)
<< "TextReader only supports inputs with one part: "
<< input.DebugString();
filename_ = TaskContext::InputFile(input);
format_.reset(DocumentFormat::Create(input.record_format(0)));
Reset();
}
Sentence *Read() {
// Skips empty sentences, e.g., blank lines at the beginning of a file or
// commented-out blocks.
vector<Sentence *> sentences;
string key, value;
while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
format_->ConvertFromString(key, value, &sentences);
CHECK_LE(sentences.size(), 1);
}
if (sentences.empty()) {
// End of file reached.
return nullptr;
} else {
++sentence_count_;
return sentences[0];
}
}
void Reset() {
sentence_count_ = 0;
tensorflow::RandomAccessFile *file;
if (filename_ == "-") {
static const int kInputBufferSize = 8 * 1024; /* bytes */
file = new StdIn();
buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
} else {
static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
TF_CHECK_OK(
tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file));
buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
}
}
private:
string filename_;
int sentence_count_ = 0;
std::unique_ptr<tensorflow::io::InputBuffer> buffer_;
std::unique_ptr<DocumentFormat> format_;
};
// Writes sentence protos to a text CoNLL file.
class TextWriter {
public:
explicit TextWriter(const TaskInput &input) {
CHECK_EQ(input.record_format_size(), 1)
<< "TextWriter only supports files with one record format: "
<< input.DebugString();
CHECK_EQ(input.part_size(), 1)
<< "TextWriter only supports files with one part: "
<< input.DebugString();
filename_ = TaskContext::InputFile(input);
format_.reset(DocumentFormat::Create(input.record_format(0)));
if (filename_ != "-") {
TF_CHECK_OK(
tensorflow::Env::Default()->NewWritableFile(filename_, &file_));
}
}
~TextWriter() {
if (file_) {
file_->Close();
delete file_;
}
}
void Write(const Sentence &sentence) {
string key, value;
format_->ConvertToString(sentence, &key, &value);
if (file_) {
TF_CHECK_OK(file_->Append(value));
} else {
std::cout << value;
}
}
private:
string filename_;
std::unique_ptr<DocumentFormat> format_;
tensorflow::WritableFile *file_ = nullptr;
};
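// A minimal usage sketch, assuming a TaskContext `context` that has already
// been populated elsewhere; "training-corpus" and "tagged-corpus" are
// illustrative input names.
//
//   TextReader reader(*context.GetInput("training-corpus"));
//   TextWriter writer(*context.GetInput("tagged-corpus"));
//   while (Sentence *sentence = reader.Read()) {
//     writer.Write(*sentence);
//     delete sentence;
//   }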
} // namespace syntaxnet
#endif // $TARGETDIR_PROTO_IO_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <deque>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence_batch.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
using tensorflow::DEVICE_CPU;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DataType;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::error::OUT_OF_RANGE;
using tensorflow::errors::InvalidArgument;
namespace syntaxnet {
class ParsingReader : public OpKernel {
public:
explicit ParsingReader(OpKernelConstruction *context) : OpKernel(context) {
string file_path, corpus_name;
OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
OP_REQUIRES_OK(context, context->GetAttr("batch_size", &max_batch_size_));
OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
// Reads task context from file.
string data;
OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
file_path, &data));
OP_REQUIRES(context,
TextFormat::ParseFromString(data, task_context_.mutable_spec()),
InvalidArgument("Could not parse task context at ", file_path));
// Set up the batch reader.
sentence_batch_.reset(
new SentenceBatch(max_batch_size_, corpus_name));
sentence_batch_->Init(&task_context_);
// Set up the parsing features and transition system.
states_.resize(max_batch_size_);
workspaces_.resize(max_batch_size_);
features_.reset(new ParserEmbeddingFeatureExtractor(arg_prefix_));
features_->Setup(&task_context_);
transition_system_.reset(ParserTransitionSystem::Create(task_context_.Get(
features_->GetParamName("transition_system"), "arc-standard")));
transition_system_->Setup(&task_context_);
features_->Init(&task_context_);
features_->RequestWorkspaces(&workspace_registry_);
transition_system_->Init(&task_context_);
string label_map_path =
TaskContext::InputFile(*task_context_.GetInput("label-map"));
label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
label_map_path, 0, 0);
// Checks that the number of feature groups matches the task context.
const int required_size = features_->embedding_dims().size();
OP_REQUIRES(
context, feature_size_ == required_size,
InvalidArgument("Task context requires feature_size=", required_size));
}
~ParsingReader() override { SharedStore::Release(label_map_); }
// Creates a new ParserState if there's another sentence to be read.
virtual void AdvanceSentence(int index) {
states_[index].reset();
if (sentence_batch_->AdvanceSentence(index)) {
states_[index].reset(new ParserState(
sentence_batch_->sentence(index),
transition_system_->NewTransitionState(true), label_map_));
workspaces_[index].Reset(workspace_registry_);
features_->Preprocess(&workspaces_[index], states_[index].get());
}
}
void Compute(OpKernelContext *context) override {
mutex_lock lock(mu_);
// Advances states to the next positions.
PerformActions(context);
// Advances any final states to the next sentences.
for (int i = 0; i < max_batch_size_; ++i) {
if (state(i) == nullptr) continue;
// Switches to the next sentence if we're at a final state.
while (transition_system_->IsFinalState(*state(i))) {
VLOG(2) << "Advancing sentence " << i;
AdvanceSentence(i);
if (state(i) == nullptr) break; // EOF has been reached
}
}
// Rewinds the corpus if no states remain in the batch.
if (sentence_batch_->size() == 0) {
++num_epochs_;
LOG(INFO) << "Starting epoch " << num_epochs_;
sentence_batch_->Rewind();
for (int i = 0; i < max_batch_size_; ++i) AdvanceSentence(i);
}
// Create the outputs for each feature space.
vector<Tensor *> feature_outputs(features_->NumEmbeddings());
for (size_t i = 0; i < feature_outputs.size(); ++i) {
OP_REQUIRES_OK(context, context->allocate_output(
i, TensorShape({sentence_batch_->size(),
features_->FeatureSize(i)}),
&feature_outputs[i]));
}
// Populate feature outputs.
for (int i = 0, index = 0; i < max_batch_size_; ++i) {
if (states_[i] == nullptr) continue;
// Extract features from the current parser state, and fill up the
// available batch slots.
std::vector<std::vector<SparseFeatures>> features =
features_->ExtractSparseFeatures(workspaces_[i], *states_[i]);
for (size_t feature_space = 0; feature_space < features.size();
++feature_space) {
int feature_size = features[feature_space].size();
CHECK(feature_size == features_->FeatureSize(feature_space));
auto features_output = feature_outputs[feature_space]->matrix<string>();
for (int k = 0; k < feature_size; ++k) {
features_output(index, k) =
features[feature_space][k].SerializeAsString();
}
}
++index;
}
// Return the number of epochs.
Tensor *epoch_output;
OP_REQUIRES_OK(context, context->allocate_output(
feature_size_, TensorShape({}), &epoch_output));
auto num_epochs = epoch_output->scalar<int32>();
num_epochs() = num_epochs_;
// Create outputs specific to this reader.
AddAdditionalOutputs(context);
}
protected:
// Performs any relevant actions on the parser states, typically either
// the gold action or a predicted action from decoding.
virtual void PerformActions(OpKernelContext *context) = 0;
// Adds outputs specific to this reader starting at additional_output_index().
virtual void AddAdditionalOutputs(OpKernelContext *context) const = 0;
// Returns the output type specification of this base class.
std::vector<DataType> default_outputs() const {
std::vector<DataType> output_types(feature_size_, DT_STRING);
output_types.push_back(DT_INT32);
return output_types;
}
// Accessors.
int max_batch_size() const { return max_batch_size_; }
int batch_size() const { return sentence_batch_->size(); }
int additional_output_index() const { return feature_size_ + 1; }
ParserState *state(int i) const { return states_[i].get(); }
const ParserTransitionSystem &transition_system() const {
return *transition_system_.get();
}
// Parser task context.
const TaskContext &task_context() const { return task_context_; }
const string &arg_prefix() const { return arg_prefix_; }
private:
// Task context used to configure this op.
TaskContext task_context_;
// Prefix for context parameters.
string arg_prefix_;
// Mutex to synchronize access to Compute().
mutex mu_;
// How many times the document source has been rewound.
int num_epochs_ = 0;
// How many sentences this op can process at any given time.
int max_batch_size_ = 1;
// Number of feature groups in the brain parser features.
int feature_size_ = -1;
// Batch of sentences, and the corresponding parser states.
std::unique_ptr<SentenceBatch> sentence_batch_;
// Batch: ParserState objects.
std::vector<std::unique_ptr<ParserState>> states_;
// Batch: WorkspaceSet objects.
std::vector<WorkspaceSet> workspaces_;
// Dependency label map used in transition system.
const TermFrequencyMap *label_map_;
// Transition system.
std::unique_ptr<ParserTransitionSystem> transition_system_;
// Typed feature extractor for embeddings.
std::unique_ptr<ParserEmbeddingFeatureExtractor> features_;
// Internal workspace registry for use in feature extraction.
WorkspaceRegistry workspace_registry_;
TF_DISALLOW_COPY_AND_ASSIGN(ParsingReader);
};
class GoldParseReader : public ParsingReader {
public:
explicit GoldParseReader(OpKernelConstruction *context)
: ParsingReader(context) {
// Sets up number and type of inputs and outputs.
std::vector<DataType> output_types = default_outputs();
output_types.push_back(DT_INT32);
OP_REQUIRES_OK(context, context->MatchSignature({}, output_types));
}
private:
// Always performs the next gold action for each state.
void PerformActions(OpKernelContext *context) override {
for (int i = 0; i < max_batch_size(); ++i) {
if (state(i) != nullptr) {
transition_system().PerformAction(
transition_system().GetNextGoldAction(*state(i)), state(i));
}
}
}
// Adds the list of gold actions for each state as an additional output.
void AddAdditionalOutputs(OpKernelContext *context) const override {
Tensor *actions_output;
OP_REQUIRES_OK(context, context->allocate_output(
additional_output_index(),
TensorShape({batch_size()}), &actions_output));
// Add all gold actions for non-null states as an additional output.
auto gold_actions = actions_output->vec<int32>();
for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
if (state(i) != nullptr) {
const int gold_action =
transition_system().GetNextGoldAction(*state(i));
gold_actions(batch_index++) = gold_action;
}
}
}
TF_DISALLOW_COPY_AND_ASSIGN(GoldParseReader);
};
REGISTER_KERNEL_BUILDER(Name("GoldParseReader").Device(DEVICE_CPU),
GoldParseReader);
// DecodedParseReader parses sentences using transition scores computed
// by a TensorFlow network. This op additionally computes a token correctness
// evaluation metric which can be used to select hyperparameter settings and a
// training stopping point.
//
// The notion of correct token is determined by the transition system, e.g.
// a tagger will return POS tag accuracy, while an arc-standard parser will
// return UAS.
//
// Which tokens should be scored is controlled by the '<arg_prefix>_scoring'
// task parameter. Possible values are
// - 'default': skips tokens with only punctuation in the tag name.
// - 'conllx': skips tokens with only punctuation in the surface form.
// - 'ignore_parens': same as conllx, but skipping parentheses as well.
// - '': scores all tokens.
class DecodedParseReader : public ParsingReader {
public:
explicit DecodedParseReader(OpKernelConstruction *context)
: ParsingReader(context) {
// Sets up number and type of inputs and outputs.
std::vector<DataType> output_types = default_outputs();
output_types.push_back(DT_INT32);
output_types.push_back(DT_STRING);
OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT}, output_types));
// Gets scoring parameters.
scoring_type_ = task_context().Get(
tensorflow::strings::StrCat(arg_prefix(), "_scoring"), "");
}
private:
void AdvanceSentence(int index) override {
ParsingReader::AdvanceSentence(index);
if (state(index)) {
docids_.push_front(state(index)->sentence().docid());
}
}
// Tallies the # of correct and incorrect tokens for a given ParserState.
void ComputeTokenAccuracy(const ParserState &state) {
for (int i = 0; i < state.sentence().token_size(); ++i) {
const Token &token = state.GetToken(i);
if (utils::PunctuationUtil::ScoreToken(token.word(), token.tag(),
scoring_type_)) {
++num_tokens_;
if (state.IsTokenCorrect(i)) ++num_correct_;
}
}
}
// Performs the allowed action with the highest score on the given state.
// Also records the accuracy whenever a terminal action is taken.
void PerformActions(OpKernelContext *context) override {
auto scores_matrix = context->input(0).matrix<float>();
num_tokens_ = 0;
num_correct_ = 0;
for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
ParserState *state = this->state(i);
if (state != nullptr) {
int best_action = 0;
float best_score = -INFINITY;
for (int action = 0; action < scores_matrix.dimension(1); ++action) {
float score = scores_matrix(batch_index, action);
if (score > best_score &&
transition_system().IsAllowedAction(action, *state)) {
best_action = action;
best_score = score;
}
}
transition_system().PerformAction(best_action, state);
// Update the # of scored correct tokens if this is the last state
// in the sentence and save the annotated document.
if (transition_system().IsFinalState(*state)) {
ComputeTokenAccuracy(*state);
sentence_map_[state->sentence().docid()] = state->sentence();
state->AddParseToDocument(&sentence_map_[state->sentence().docid()]);
}
++batch_index;
}
}
}
// Adds the evaluation metrics and annotated documents as additional outputs,
// if there were any terminal states.
void AddAdditionalOutputs(OpKernelContext *context) const override {
Tensor *counts_output;
OP_REQUIRES_OK(context,
context->allocate_output(additional_output_index(),
TensorShape({2}), &counts_output));
auto eval_metrics = counts_output->vec<int32>();
eval_metrics(0) = num_tokens_;
eval_metrics(1) = num_correct_;
// Output annotated documents for each state. To preserve order, repeatedly
// pull from the back of the docids queue as long as the sentences have been
// completely processed. If the next document has not been completely
// processed yet, then the docid will not be found in 'sentence_map_'.
vector<Sentence> sentences;
while (!docids_.empty() &&
sentence_map_.find(docids_.back()) != sentence_map_.end()) {
sentences.emplace_back(sentence_map_[docids_.back()]);
sentence_map_.erase(docids_.back());
docids_.pop_back();
}
Tensor *annotated_output;
OP_REQUIRES_OK(context,
context->allocate_output(
additional_output_index() + 1,
TensorShape({static_cast<int64>(sentences.size())}),
&annotated_output));
auto document_output = annotated_output->vec<string>();
for (size_t i = 0; i < sentences.size(); ++i) {
document_output(i) = sentences[i].SerializeAsString();
}
}
// State for eval metric computation.
int num_tokens_ = 0;
int num_correct_ = 0;
// Parameter for deciding which tokens to score.
string scoring_type_;
mutable std::deque<string> docids_;
mutable map<string, Sentence> sentence_map_;
TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
};
REGISTER_KERNEL_BUILDER(Name("DecodedParseReader").Device(DEVICE_CPU),
DecodedParseReader);
class WordEmbeddingInitializer : public OpKernel {
public:
explicit WordEmbeddingInitializer(OpKernelConstruction *context)
: OpKernel(context) {
string file_path, data;
OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
file_path, &data));
OP_REQUIRES(context,
TextFormat::ParseFromString(data, task_context_.mutable_spec()),
InvalidArgument("Could not parse task context at ", file_path));
OP_REQUIRES_OK(context, context->GetAttr("vectors", &vectors_path_));
OP_REQUIRES_OK(context,
context->GetAttr("embedding_init", &embedding_init_));
// Sets up number and type of inputs and outputs.
OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_FLOAT}));
}
void Compute(OpKernelContext *context) override {
// Loads words from vocabulary with mapping to ids.
string path = TaskContext::InputFile(*task_context_.GetInput("word-map"));
const TermFrequencyMap *word_map =
SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
unordered_map<string, int64> vocab;
for (int i = 0; i < word_map->Size(); ++i) {
vocab[word_map->GetTerm(i)] = i;
}
// Creates a reader pointing to a local copy of the vectors recordio.
string tmp_vectors_path;
OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &tmp_vectors_path));
ProtoRecordReader reader(tmp_vectors_path);
// Loads the embedding vectors into a matrix.
Tensor *embedding_matrix = nullptr;
TokenEmbedding embedding;
while (reader.Read(&embedding) == tensorflow::Status::OK()) {
if (embedding_matrix == nullptr) {
const int embedding_size = embedding.vector().values_size();
OP_REQUIRES_OK(
context, context->allocate_output(
0, TensorShape({word_map->Size() + 3, embedding_size}),
&embedding_matrix));
embedding_matrix->matrix<float>()
.setRandom<Eigen::internal::NormalRandomGenerator<float>>();
embedding_matrix->matrix<float>() =
embedding_matrix->matrix<float>() * static_cast<float>(
embedding_init_ / sqrt(embedding_size));
}
if (vocab.find(embedding.token()) != vocab.end()) {
SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
embedding_matrix);
}
}
}
private:
// Sets embedding_matrix[row] to a normalized version of the given vector.
void SetNormalizedRow(const TokenEmbedding::Vector &vector, const int row,
Tensor *embedding_matrix) {
float norm = 0.0f;
for (int col = 0; col < vector.values_size(); ++col) {
float val = vector.values(col);
norm += val * val;
}
norm = sqrt(norm);
for (int col = 0; col < vector.values_size(); ++col) {
embedding_matrix->matrix<float>()(row, col) = vector.values(col) / norm;
}
}
// Copies the file at source_path to a temporary file and sets tmp_path to the
// temporary file's location. This is helpful since reading from non-local
// files with a record reader can be very slow.
static tensorflow::Status CopyToTmpPath(const string &source_path,
string *tmp_path) {
// Opens source file.
tensorflow::RandomAccessFile *source_file;
TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewRandomAccessFile(
source_path, &source_file));
std::unique_ptr<tensorflow::RandomAccessFile> source_file_deleter(
source_file);
// Creates destination file.
tensorflow::WritableFile *target_file;
*tmp_path = tensorflow::strings::Printf(
"/tmp/%d.%lld", getpid(), tensorflow::Env::Default()->NowMicros());
TF_RETURN_IF_ERROR(
tensorflow::Env::Default()->NewWritableFile(*tmp_path, &target_file));
std::unique_ptr<tensorflow::WritableFile> target_file_deleter(target_file);
// Performs copy.
tensorflow::Status s;
const size_t kBytesToRead = 10 << 20; // 10MB at a time.
string scratch;
scratch.resize(kBytesToRead);
for (uint64 offset = 0; s.ok(); offset += kBytesToRead) {
tensorflow::StringPiece data;
s.Update(source_file->Read(offset, kBytesToRead, &data, &scratch[0]));
target_file->Append(data);
}
if (s.code() == OUT_OF_RANGE) {
return tensorflow::Status::OK();
} else {
return s;
}
}
// Task context used to configure this op.
TaskContext task_context_;
// Embedding vectors that are not found in the input vectors recordio are
// initialized randomly from a normal distribution with zero mean and
// std dev = embedding_init_ / sqrt(embedding_size).
float embedding_init_ = 1.f;
// Path to recordio with word embedding vectors.
string vectors_path_;
};
REGISTER_KERNEL_BUILDER(Name("WordEmbeddingInitializer").Device(DEVICE_CPU),
WordEmbeddingInitializer);
} // namespace syntaxnet
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reader_ops."""
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging
from syntaxnet import dictionary_pb2
from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
# Creates a task context with the correct testing paths.
initial_task_context = os.path.join(
FLAGS.test_srcdir,
'syntaxnet/'
'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
.replace('OUTPATH', FLAGS.test_tmpdir))
# Creates necessary term maps.
with self.test_session() as sess:
gen_parser_ops.lexicon_builder(task_context=self._task_context,
corpus_name='training-corpus').run()
self._num_features, self._num_feature_ids, _, self._num_actions = (
sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
arg_prefix='brain_parser')))
def GetMaxId(self, sparse_features):
max_id = 0
for x in sparse_features:
for y in x:
f = sparse_pb2.SparseFeatures()
f.ParseFromString(y)
for i in f.id:
max_id = max(i, max_id)
return max_id
def testParsingReaderOp(self):
# Runs the reader over the test input for two epochs.
num_steps_a = 0
num_actions = 0
num_word_ids = 0
num_tag_ids = 0
num_label_ids = 0
batch_size = 10
with self.test_session() as sess:
(words, tags, labels), epochs, gold_actions = (
gen_parser_ops.gold_parse_reader(self._task_context,
3,
batch_size,
corpus_name='training-corpus'))
while True:
tf_gold_actions, tf_epochs, tf_words, tf_tags, tf_labels = (
sess.run([gold_actions, epochs, words, tags, labels]))
num_steps_a += 1
num_actions = max(num_actions, max(tf_gold_actions) + 1)
num_word_ids = max(num_word_ids, self.GetMaxId(tf_words) + 1)
num_tag_ids = max(num_tag_ids, self.GetMaxId(tf_tags) + 1)
num_label_ids = max(num_label_ids, self.GetMaxId(tf_labels) + 1)
self.assertIn(tf_epochs, [0, 1, 2])
if tf_epochs > 1:
break
# Runs the reader again, this time with a lot of added graph nodes.
num_steps_b = 0
with self.test_session() as sess:
num_features = [6, 6, 4]
num_feature_ids = [num_word_ids, num_tag_ids, num_label_ids]
embedding_sizes = [8, 8, 8]
hidden_layer_sizes = [32, 32]
# Here we aim to test the iteration of the reader op in a complex network,
# not the GraphBuilder.
parser = graph_builder.GreedyParser(
num_actions, num_features, num_feature_ids, embedding_sizes,
hidden_layer_sizes)
parser.AddTraining(self._task_context,
batch_size,
corpus_name='training-corpus')
sess.run(parser.inits.values())
while True:
tf_epochs, tf_cost, _ = sess.run(
[parser.training['epochs'], parser.training['cost'],
parser.training['train_op']])
num_steps_b += 1
self.assertGreaterEqual(tf_cost, 0)
self.assertIn(tf_epochs, [0, 1, 2])
if tf_epochs > 1:
break
# Assert that the two runs made the exact same number of steps.
logging.info('Number of steps in the two runs: %d, %d',
num_steps_a, num_steps_b)
self.assertEqual(num_steps_a, num_steps_b)
def testParsingReaderOpWhileLoop(self):
feature_size = 3
batch_size = 5
def ParserEndpoints():
return gen_parser_ops.gold_parse_reader(self._task_context,
feature_size,
batch_size,
corpus_name='training-corpus')
with self.test_session() as sess:
# The 'condition' and 'body' functions expect as many arguments as there
# are loop variables. 'condition' depends on the 'epoch' loop variable
# only, so we disregard the remaining unused function arguments. 'body'
# returns a list of updated loop variables.
def Condition(epoch, *unused_args):
return tf.less(epoch, 2)
def Body(epoch, num_actions, *feature_args):
# By adding one of the outputs of the reader op ('epoch') as a control
# dependency to the reader op we force the repeated evaluation of the
# reader op.
with epoch.graph.control_dependencies([epoch]):
features, epoch, gold_actions = ParserEndpoints()
num_actions = tf.maximum(num_actions,
tf.reduce_max(gold_actions, [0], False) + 1)
feature_ids = []
for i in range(len(feature_args)):
feature_ids.append(features[i])
return [epoch, num_actions] + feature_ids
epoch = ParserEndpoints()[-2]
num_actions = tf.constant(0)
loop_vars = [epoch, num_actions]
res = sess.run(
cf.While(Condition, Body, loop_vars, parallel_iterations=1))
logging.info('Result: %s', res)
self.assertEqual(res[0], 2)
def testWordEmbeddingInitializer(self):
def _TokenEmbedding(token, embedding):
e = dictionary_pb2.TokenEmbedding()
e.token = token
e.vector.values.extend(embedding)
return e.SerializeToString()
# Provide embeddings for the first three words in the word map.
records_path = os.path.join(FLAGS.test_tmpdir, 'sstable-00000-of-00001')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(_TokenEmbedding('.', [1, 2]))
writer.write(_TokenEmbedding(',', [3, 4]))
writer.write(_TokenEmbedding('the', [5, 6]))
del writer
with self.test_session():
embeddings = gen_parser_ops.word_embedding_initializer(
vectors=records_path,
task_context=self._task_context).eval()
self.assertAllClose(
np.array([[1. / (1 + 4) ** .5, 2. / (1 + 4) ** .5],
[3. / (9 + 16) ** .5, 4. / (9 + 16) ** .5],
[5. / (25 + 36) ** .5, 6. / (25 + 36) ** .5]]),
embeddings[:3,])
if __name__ == '__main__':
googletest.main()
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/registry.h"
namespace syntaxnet {
// Global list of all component registries.
RegistryMetadata *global_registry_list = NULL;
void RegistryMetadata::Register(RegistryMetadata *registry) {
registry->set_link(global_registry_list);
global_registry_list = registry;
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registry for component registration. These classes can be used for creating
// registries of components conforming to the same interface. This is useful for
// building a component-based architecture where the specific implementation
// classes can be selected at runtime. There is support for both class-based and
// instance-based registries.
//
// Example:
// function.h:
//
// class Function : public RegisterableInstance<Function> {
// public:
// virtual double Evaluate(double x) = 0;
// };
//
// #define REGISTER_FUNCTION(type, component)
// REGISTER_INSTANCE_COMPONENT(Function, type, component);
//
// function.cc:
//
// REGISTER_INSTANCE_REGISTRY("function", Function);
//
// class Cos : public Function {
// public:
// double Evaluate(double x) { return cos(x); }
// };
//
// class Exp : public Function {
// public:
// double Evaluate(double x) { return exp(x); }
// };
//
// REGISTER_FUNCTION("cos", Cos);
// REGISTER_FUNCTION("exp", Exp);
//
// Function *f = Function::Lookup("cos");
// double result = f->Evaluate(arg);
#ifndef $TARGETDIR_REGISTRY_H_
#define $TARGETDIR_REGISTRY_H_
#include <string.h>
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
namespace syntaxnet {
// Component metadata with information about name, class, and code location.
class ComponentMetadata {
public:
ComponentMetadata(const char *name, const char *class_name, const char *file,
int line)
: name_(name),
class_name_(class_name),
file_(file),
line_(line),
link_(NULL) {}
// Returns component name.
const char *name() const { return name_; }
// Metadata objects can be linked in a list.
ComponentMetadata *link() const { return link_; }
void set_link(ComponentMetadata *link) { link_ = link; }
private:
// Component name.
const char *name_;
// Name of class for component.
const char *class_name_;
// Code file and location where the component was registered.
const char *file_;
int line_;
// Link to next metadata object in list.
ComponentMetadata *link_;
};
// The master registry contains all registered component registries. A registry
// is not registered in the master registry until the first component of that
// type is registered.
class RegistryMetadata : public ComponentMetadata {
public:
RegistryMetadata(const char *name, const char *class_name, const char *file,
int line, ComponentMetadata **components)
: ComponentMetadata(name, class_name, file, line),
components_(components) {}
// Registers a component registry in the master registry.
static void Register(RegistryMetadata *registry);
private:
// Location of list of components in registry.
ComponentMetadata **components_;
};
// Registry for components. An object can be registered with a type name in the
// registry. The named instances in the registry can be returned using the
// Lookup() method. The components in the registry are put into a linked list
// of components. It is important that the component registry can be statically
// initialized in order not to depend on initialization order.
template <class T>
struct ComponentRegistry {
typedef ComponentRegistry<T> Self;
// Component registration class.
class Registrar : public ComponentMetadata {
public:
// Registers new component by linking itself into the component list of
// the registry.
Registrar(Self *registry, const char *type, const char *class_name,
const char *file, int line, T *object)
: ComponentMetadata(type, class_name, file, line), object_(object) {
// Register registry in master registry if this is the first registered
// component of this type.
if (registry->components == NULL) {
RegistryMetadata::Register(new RegistryMetadata(
registry->name, registry->class_name, registry->file,
registry->line,
reinterpret_cast<ComponentMetadata **>(&registry->components)));
}
// Register component in registry.
set_link(registry->components);
registry->components = this;
}
// Returns component type.
const char *type() const { return name(); }
// Returns component object.
T *object() const { return object_; }
// Returns the next component in the component list.
Registrar *next() const { return static_cast<Registrar *>(link()); }
private:
// Component object.
T *object_;
};
// Finds registrar for named component in registry.
const Registrar *GetComponent(const char *type) const {
Registrar *r = components;
while (r != NULL && strcmp(type, r->type()) != 0) r = r->next();
if (r == NULL) {
LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
}
return r;
}
// Finds a named component in the registry.
T *Lookup(const char *type) const { return GetComponent(type)->object(); }
T *Lookup(const string &type) const { return Lookup(type.c_str()); }
// Textual description of the kind of components in the registry.
const char *name;
// Base class name of component type.
const char *class_name;
// File and line where the registry is defined.
const char *file;
int line;
// Linked list of registered components.
Registrar *components;
};
// Base class for registerable class-based components.
template <class T>
class RegisterableClass {
public:
// Factory function type.
typedef T *(Factory)();
// Registry type.
typedef ComponentRegistry<Factory> Registry;
// Creates a new component instance.
static T *Create(const string &type) { return registry()->Lookup(type)(); }
// Returns registry for class.
static Registry *registry() { return &registry_; }
private:
// Registry for class.
static Registry registry_;
};
// Base class for registerable instance-based components.
template <class T>
class RegisterableInstance {
public:
// Registry type.
typedef ComponentRegistry<T> Registry;
private:
// Registry for class.
static Registry registry_;
};
#define REGISTER_CLASS_COMPONENT(base, type, component) \
static base *__##component##__factory() { return new component; } \
static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, \
__##component##__factory)
#define REGISTER_CLASS_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableClass<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
#define REGISTER_INSTANCE_COMPONENT(base, type, component) \
static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, new component)
#define REGISTER_INSTANCE_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableInstance<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
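// The class-based registry is used analogously to the instance-based example
// at the top of this file (a sketch only; Shape and Circle are hypothetical):
//
//   class Shape : public RegisterableClass<Shape> { ... };
//   REGISTER_CLASS_REGISTRY("shape", Shape);             // once, in a .cc file
//   REGISTER_CLASS_COMPONENT(Shape, "circle", Circle);   // per implementation
//   Shape *shape = Shape::Create("circle");              // constructs a new Circle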
} // namespace syntaxnet
#endif // $TARGETDIR_REGISTRY_H_
// Protocol buffer specification for document analysis.
syntax = "proto2";
package syntaxnet;
// A Sentence contains the raw text contents of a sentence, as well as an
// analysis.
message Sentence {
// Identifier for document.
optional string docid = 1;
// Raw text contents of the sentence.
optional string text = 2;
// Tokenization of the sentence.
repeated Token token = 3;
extensions 1000 to max;
}
// A document token marks a span of bytes in the document text as a token
// or word.
message Token {
// Token word form.
required string word = 1;
// Start position of token in text.
required int32 start = 2;
// End position of token in text. Gives index of last byte, not one past
// the last byte. If token came from lexer, excludes any trailing HTML tags.
required int32 end = 3;
// Head of this token in the dependency tree: the id of the token which has an
// arc going to this one. If it is the root token of a sentence, then it is
// set to -1.
optional int32 head = 4 [default = -1];
// Part-of-speech tag for token.
optional string tag = 5;
// Coarse-grained word category for token.
optional string category = 6;
// Label for dependency relation between this token and its head.
optional string label = 7;
// Break level that indicates how this token was separated from the previous
// token in the text.
enum BreakLevel {
NO_BREAK = 0; // No separation between tokens.
SPACE_BREAK = 1; // Tokens separated by space.
LINE_BREAK = 2; // Tokens separated by line break.
SENTENCE_BREAK = 3; // Tokens separated by sentence break.
}
optional BreakLevel break_level = 8 [default = SPACE_BREAK];
extensions 1000 to max;
}
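// An illustrative Sentence in protocol buffer text format (the head, tag, and
// label values below are made up for the example):
//
//   docid: "doc-1"
//   text: "Hello world"
//   token { word: "Hello" start: 0 end: 4 head: 1 tag: "UH" label: "discourse" }
//   token { word: "world" start: 6 end: 10 head: -1 tag: "NN" label: "ROOT" }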
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_batch.h"
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/task_context.h"
namespace syntaxnet {
void SentenceBatch::Init(TaskContext *context) {
reader_.reset(new TextReader(*context->GetInput(input_name_)));
size_ = 0;
}
bool SentenceBatch::AdvanceSentence(int index) {
if (sentences_[index] == nullptr) ++size_;
sentences_[index].reset();
std::unique_ptr<Sentence> sentence(reader_->Read());
if (sentence == nullptr) {
--size_;
return false;
}
// Preprocess the new sentence for the parser state.
sentences_[index] = std::move(sentence);
return true;
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_SENTENCE_BATCH_H_
#define $TARGETDIR_SENTENCE_BATCH_H_
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/embedding_feature_extractor.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
namespace syntaxnet {
// Helper class to manage generating batches of preprocessed ParserState objects
// by reading in multiple sentences in parallel.
class SentenceBatch {
public:
SentenceBatch(int batch_size, string input_name)
: batch_size_(batch_size),
input_name_(input_name),
sentences_(batch_size) {}
// Initializes all resources and opens the corpus file.
void Init(TaskContext *context);
// Advances the index'th sentence in the batch to the next sentence. This will
// create and preprocess a new ParserState for that element. Returns false if
// EOF is reached (on EOF, the sentence slot is also set to nullptr).
bool AdvanceSentence(int index);
// Rewinds the corpus reader.
void Rewind() { reader_->Reset(); }
int size() const { return size_; }
Sentence *sentence(int index) { return sentences_[index].get(); }
private:
// Running tally of non-nullptr states in the batch.
int size_;
// Maximum number of states in the batch.
int batch_size_;
// Input to read from the TaskContext.
string input_name_;
// Reader for the corpus.
std::unique_ptr<TextReader> reader_;
// Batch: Sentence objects.
std::vector<std::unique_ptr<Sentence>> sentences_;
};
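// A minimal sketch of how a batch is typically driven, mirroring the reader
// ops; "training-corpus" is an illustrative input name:
//
//   SentenceBatch batch(batch_size, "training-corpus");
//   batch.Init(&task_context);
//   for (int i = 0; i < batch_size; ++i) batch.AdvanceSentence(i);
//   // Consume batch.sentence(i) for the slots where AdvanceSentence() succeeded.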
} // namespace syntaxnet
#endif // $TARGETDIR_SENTENCE_BATCH_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/registry.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
TermFrequencyMapFeature::~TermFrequencyMapFeature() {
if (term_map_ != nullptr) {
SharedStore::Release(term_map_);
term_map_ = nullptr;
}
}
void TermFrequencyMapFeature::Setup(TaskContext *context) {
TokenLookupFeature::Setup(context);
context->GetInput(input_name_, "text", "");
}
void TermFrequencyMapFeature::Init(TaskContext *context) {
min_freq_ = GetIntParameter("min-freq", 0);
max_num_terms_ = GetIntParameter("max-num-terms", 0);
file_name_ = context->InputFile(*context->GetInput(input_name_));
term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
file_name_, min_freq_, max_num_terms_);
TokenLookupFeature::Init(context);
}
string TermFrequencyMapFeature::GetFeatureValueName(FeatureValue value) const {
if (value == UnknownValue()) return "<UNKNOWN>";
if (value >= 0 && value < (NumValues() - 1)) {
return term_map_->GetTerm(value);
}
LOG(ERROR) << "Invalid feature value: " << value;
return "<INVALID>";
}
string TermFrequencyMapFeature::WorkspaceName() const {
return SharedStoreUtils::CreateDefaultName("term-frequency-map", input_name_,
min_freq_, max_num_terms_);
}
string Hyphen::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_HYPHEN:
return "NO_HYPHEN";
case HAS_HYPHEN:
return "HAS_HYPHEN";
}
return "<INVALID>";
}
FeatureValue Hyphen::ComputeValue(const Token &token) const {
const string &word = token.word();
return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
}
string Digit::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_DIGIT:
return "NO_DIGIT";
case SOME_DIGIT:
return "SOME_DIGIT";
case ALL_DIGIT:
return "ALL_DIGIT";
}
return "<INVALID>";
}
FeatureValue Digit::ComputeValue(const Token &token) const {
const string &word = token.word();
bool has_digit = isdigit(word[0]);
bool all_digit = has_digit;
for (size_t i = 1; i < word.length(); ++i) {
bool char_is_digit = isdigit(word[i]);
all_digit = all_digit && char_is_digit;
has_digit = has_digit || char_is_digit;
if (!all_digit && has_digit) return SOME_DIGIT;
}
if (!all_digit) return NO_DIGIT;
return ALL_DIGIT;
}
AffixTableFeature::AffixTableFeature(AffixTable::Type type)
: type_(type) {
if (type == AffixTable::PREFIX) {
input_name_ = "prefix-table";
} else {
input_name_ = "suffix-table";
}
}
AffixTableFeature::~AffixTableFeature() {
SharedStore::Release(affix_table_);
affix_table_ = nullptr;
}
string AffixTableFeature::WorkspaceName() const {
return SharedStoreUtils::CreateDefaultName(
"affix-table", input_name_, type_, affix_length_);
}
// Utility function to create a new affix table without changing constructors,
// to be called by the SharedStore.
static AffixTable *CreateAffixTable(const string &filename,
AffixTable::Type type) {
AffixTable *affix_table = new AffixTable(type, 1);
tensorflow::RandomAccessFile *file;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
ProtoRecordReader reader(file);
affix_table->Read(&reader);
return affix_table;
}
void AffixTableFeature::Setup(TaskContext *context) {
context->GetInput(input_name_, "recordio", "affix-table");
affix_length_ = GetIntParameter("length", 0);
CHECK_GE(affix_length_, 0)
<< "Length must be specified for affix preprocessor.";
TokenLookupFeature::Setup(context);
}
void AffixTableFeature::Init(TaskContext *context) {
string filename = context->InputFile(*context->GetInput(input_name_));
// Get the shared AffixTable object.
std::function<AffixTable *()> closure =
std::bind(CreateAffixTable, filename, type_);
affix_table_ = SharedStore::ClosureGetOrDie(filename, &closure);
CHECK_GE(affix_table_->max_length(), affix_length_)
<< "Affixes of length " << affix_length_ << " needed, but the affix "
<< "table only provides affixes of length <= "
<< affix_table_->max_length() << ".";
TokenLookupFeature::Init(context);
}
FeatureValue AffixTableFeature::ComputeValue(const Token &token) const {
const string &word = token.word();
UnicodeText text;
text.PointToUTF8(word.c_str(), word.size());
if (affix_length_ > text.size()) return UnknownValue();
UnicodeText::const_iterator start, end;
if (type_ == AffixTable::PREFIX) {
start = end = text.begin();
for (int i = 0; i < affix_length_; ++i) ++end;
} else {
start = end = text.end();
for (int i = 0; i < affix_length_; ++i) --start;
}
string affix(start.utf8_data(), end.utf8_data() - start.utf8_data());
int affix_id = affix_table_->AffixId(affix);
return affix_id == -1 ? UnknownValue() : affix_id;
}
string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
if (value == UnknownValue()) return "<UNKNOWN>";
if (value >= 0 && value < UnknownValue()) {
return affix_table_->AffixForm(value);
}
LOG(ERROR) << "Invalid feature value: " << value;
return "<INVALID>";
}
// Registry for the Sentence + token index feature functions.
REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);
// Register the features defined in the header.
REGISTER_SENTENCE_IDX_FEATURE("word", Word);
REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Features that operate on Sentence objects. Most features are defined
// in this header so they may be re-used via composition into other more
// advanced feature classes.
#ifndef $TARGETDIR_SENTENCE_FEATURES_H_
#define $TARGETDIR_SENTENCE_FEATURES_H_
#include "syntaxnet/affix.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"
namespace syntaxnet {
// Feature function for any component that processes Sentences, whose
// focus is a token index into the sentence.
typedef FeatureFunction<Sentence, int> SentenceFeature;
// Alias for Locator type features that take (Sentence, int) signatures
// and call other (Sentence, int) features.
template <class DER>
using Locator = FeatureLocator<DER, Sentence, int>;
class TokenLookupFeature : public SentenceFeature {
public:
void Init(TaskContext *context) override {
set_feature_type(new ResourceBasedFeatureType<TokenLookupFeature>(
name(), this, {{NumValues(), "<OUTSIDE>"}}));
}
// Given a position in a sentence and workspaces, looks up the corresponding
// feature value. The index is relative to the start of the sentence.
virtual FeatureValue ComputeValue(const Token &token) const = 0;
// Number of unique values.
virtual int64 NumValues() const = 0;
// Convert the numeric value of the feature to a human readable string.
virtual string GetFeatureValueName(FeatureValue value) const = 0;
// Name of the shared workspace.
virtual string WorkspaceName() const = 0;
// Runs ComputeValue for each token in the sentence.
void Preprocess(WorkspaceSet *workspaces,
Sentence *sentence) const override {
if (workspaces->Has<VectorIntWorkspace>(workspace_)) return;
VectorIntWorkspace *workspace = new VectorIntWorkspace(
sentence->token_size());
for (int i = 0; i < sentence->token_size(); ++i) {
const int value = ComputeValue(sentence->token(i));
workspace->set_element(i, value);
}
workspaces->Set<VectorIntWorkspace>(workspace_, workspace);
}
// Requests a vector of ints to store in the workspace registry.
void RequestWorkspaces(WorkspaceRegistry *registry) override {
workspace_ = registry->Request<VectorIntWorkspace>(WorkspaceName());
}
// Returns the precomputed value, or NumValues() for features outside
// the sentence.
FeatureValue Compute(const WorkspaceSet &workspaces,
const Sentence &sentence, int focus,
const FeatureVector *result) const override {
if (focus < 0 || focus >= sentence.token_size()) return NumValues();
return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
}
private:
int workspace_;
};
// Lookup feature that uses a TermFrequencyMap to store a string->int mapping.
class TermFrequencyMapFeature : public TokenLookupFeature {
public:
explicit TermFrequencyMapFeature(const string &input_name)
: input_name_(input_name), min_freq_(0), max_num_terms_(0) {}
~TermFrequencyMapFeature() override;
// Requests the input map as a resource.
void Setup(TaskContext *context) override;
// Loads the input map into memory (using the SharedStore to avoid redundancy).
void Init(TaskContext *context) override;
// Number of unique values.
virtual int64 NumValues() const { return term_map_->Size() + 1; }
// Special value for strings not in the map.
FeatureValue UnknownValue() const { return term_map_->Size(); }
// Uses the TermFrequencyMap to lookup the string associated with a value.
string GetFeatureValueName(FeatureValue value) const override;
// Name of the shared workspace.
string WorkspaceName() const override;
protected:
const TermFrequencyMap &term_map() const { return *term_map_; }
private:
// Shortcut pointer to shared map. Not owned.
const TermFrequencyMap *term_map_ = nullptr;
// Name of the input for the term map.
string input_name_;
// Filename of the underlying resource.
string file_name_;
// Minimum frequency for term map.
int min_freq_;
// Maximum number of terms for term map.
int max_num_terms_;
};
class Word : public TermFrequencyMapFeature {
public:
Word() : TermFrequencyMapFeature("word-map") {}
FeatureValue ComputeValue(const Token &token) const override {
string form = token.word();
return term_map().LookupIndex(form, UnknownValue());
}
};
class LowercaseWord : public TermFrequencyMapFeature {
public:
LowercaseWord() : TermFrequencyMapFeature("lc-word-map") {}
FeatureValue ComputeValue(const Token &token) const override {
const string lcword = utils::Lowercase(token.word());
return term_map().LookupIndex(lcword, UnknownValue());
}
};
class Tag : public TermFrequencyMapFeature {
public:
Tag() : TermFrequencyMapFeature("tag-map") {}
FeatureValue ComputeValue(const Token &token) const override {
return term_map().LookupIndex(token.tag(), UnknownValue());
}
};
class Label : public TermFrequencyMapFeature {
public:
Label() : TermFrequencyMapFeature("label-map") {}
FeatureValue ComputeValue(const Token &token) const override {
return term_map().LookupIndex(token.label(), UnknownValue());
}
};
class LexicalCategoryFeature : public TokenLookupFeature {
public:
LexicalCategoryFeature(const string &name, int cardinality)
: name_(name), cardinality_(cardinality) {}
~LexicalCategoryFeature() override {}
FeatureValue NumValues() const override { return cardinality_; }
// Returns the identifier for the workspace for this preprocessor.
string WorkspaceName() const override {
return tensorflow::strings::StrCat(name_, ":", cardinality_);
}
private:
// Name of the category type.
const string name_;
// Number of values.
const int cardinality_;
};
// Preprocessor that computes whether a word has a hyphen or not.
class Hyphen : public LexicalCategoryFeature {
public:
// Enumeration of values.
enum Category {
NO_HYPHEN = 0,
HAS_HYPHEN = 1,
CARDINALITY = 2,
};
// Default constructor.
Hyphen() : LexicalCategoryFeature("hyphen", CARDINALITY) {}
// Returns a string representation of the enum value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the category value for the token.
FeatureValue ComputeValue(const Token &token) const override;
};
// Preprocessor that computes whether a word contains no, some, or only digits.
class Digit : public LexicalCategoryFeature {
public:
// Enumeration of values.
enum Category {
NO_DIGIT = 0,
SOME_DIGIT = 1,
ALL_DIGIT = 2,
CARDINALITY = 3,
};
// Default constructor.
Digit() : LexicalCategoryFeature("digit", CARDINALITY) {}
// Returns a string representation of the enum value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the category value for the token.
FeatureValue ComputeValue(const Token &token) const override;
};
// TokenLookupFeature object to compute prefixes and suffixes of words. The
// AffixTable is stored in the SharedStore. This is very similar to the
// implementation of TermFrequencyMapFeature, but uses an AffixTable to
// perform the lookups. There are only two specializations, for prefixes and
// suffixes.
class AffixTableFeature : public TokenLookupFeature {
public:
// Explicit constructor to set the type of the table. This determines the
// requested input.
explicit AffixTableFeature(AffixTable::Type type);
~AffixTableFeature() override;
// Requests inputs for the affix table.
void Setup(TaskContext *context) override;
// Loads the affix table from the SharedStore.
void Init(TaskContext *context) override;
// The workspace name is specific to which affix length we are computing.
string WorkspaceName() const override;
// Returns the total number of affixes in the table, regardless of specified
// length.
FeatureValue NumValues() const override { return affix_table_->size() + 1; }
// Special value for strings not in the map.
FeatureValue UnknownValue() const { return affix_table_->size(); }
// Looks up the affix for a given word.
FeatureValue ComputeValue(const Token &token) const override;
// Returns the string associated with a value.
string GetFeatureValueName(FeatureValue value) const override;
private:
// Size parameter for the affix table.
int affix_length_;
// Name of the input for the table.
string input_name_;
// The type of the affix table.
const AffixTable::Type type_;
// Affix table used for indexing. This comes from the shared store, and is not
// owned directly.
const AffixTable *affix_table_ = nullptr;
};
// Specific instantiation for computing prefixes. This requires the input
// "prefix-table".
class PrefixFeature : public AffixTableFeature {
public:
PrefixFeature() : AffixTableFeature(AffixTable::PREFIX) {}
};
// Specific instantiation for computing suffixes. Requires the input
// "suffix-table."
class SuffixFeature : public AffixTableFeature {
public:
SuffixFeature() : AffixTableFeature(AffixTable::SUFFIX) {}
};
// Offset locator. Simple locator: just changes the focus by some offset.
class Offset : public Locator<Offset> {
public:
void UpdateArgs(const WorkspaceSet &workspaces,
const Sentence &sentence, int *focus) const {
*focus += argument();
}
};
typedef FeatureExtractor<Sentence, int> SentenceExtractor;
// Utility to register the sentence_instance::Feature functions.
#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
REGISTER_FEATURE_FUNCTION(SentenceFeature, name, type)
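// Example (illustrative; the registrations live in the corresponding .cc file
// and the exact names used there are assumed here):
//
//   REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
//   REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
//
// With such registrations, an FML string like "offset(-1).tag" (exercised in
// the unit tests below) moves the focus back one token and extracts that
// token's tag.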
} // namespace syntaxnet
#endif // $TARGETDIR_SENTENCE_FEATURES_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_features.h"
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include "syntaxnet/utils.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/workspace.h"
using testing::UnorderedElementsAreArray;
namespace syntaxnet {
// A basic fixture for testing features. The constructor takes a text-format
// Sentence protobuf string that is used as the test data.
class SentenceFeaturesTest : public ::testing::Test {
protected:
explicit SentenceFeaturesTest(const string &prototxt)
: sentence_(ParseASCII(prototxt)),
creators_(PopulateTestInputs::Defaults(sentence_)) {}
static Sentence ParseASCII(const string &prototxt) {
Sentence document;
CHECK(TextFormat::ParseFromString(prototxt, &document));
return document;
}
// Prepares a new feature for extraction from the attached sentence,
// regenerating the TaskContext and all resources.
virtual void PrepareFeature(const string &fml) {
context_.mutable_spec()->mutable_input()->Clear();
context_.mutable_spec()->mutable_output()->Clear();
extractor_.reset(new SentenceExtractor());
extractor_->Parse(fml);
extractor_->Setup(&context_);
creators_.Populate(&context_);
extractor_->Init(&context_);
extractor_->RequestWorkspaces(&registry_);
workspaces_.Reset(registry_);
extractor_->Preprocess(&workspaces_, &sentence_);
}
// Returns the string representation of the prepared feature extracted at the
// given index.
virtual string ExtractFeature(int index) {
FeatureVector result;
extractor_->ExtractFeatures(workspaces_, sentence_, index,
&result);
return result.type(0)->GetFeatureValueName(result.value(0));
}
// Extracts a vector of string representations by evaluating the prepared
// feature (which may return multiple values) at the given index.
virtual vector<string> ExtractMultiFeature(int index) {
vector<string> values;
FeatureVector result;
extractor_->ExtractFeatures(workspaces_, sentence_, index,
&result);
for (int i = 0; i < result.size(); ++i) {
values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
}
return values;
}
Sentence sentence_;
WorkspaceSet workspaces_;
PopulateTestInputs::CreatorMap creators_;
TaskContext context_;
WorkspaceRegistry registry_;
std::unique_ptr<SentenceExtractor> extractor_;
};
// Test fixture for simple common features that operate on just a sentence.
class CommonSentenceFeaturesTest : public SentenceFeaturesTest {
protected:
CommonSentenceFeaturesTest()
: SentenceFeaturesTest(
"text: 'I saw a man with a telescope.' "
"token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
" head: 1 label: 'nsubj' break_level: NO_BREAK } "
"token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
" label: 'ROOT' break_level: SPACE_BREAK } "
"token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
" head: 3 label: 'det' break_level: SPACE_BREAK } "
"token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
" head: 1 label: 'dobj' break_level: SPACE_BREAK } "
"token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
" head: 1 label: 'prep' break_level: SPACE_BREAK } "
"token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
" head: 6 label: 'det' break_level: SPACE_BREAK } "
"token { word: 'telescope' start: 19 end: 27 tag: 'NN' category: "
"'NOUN'"
" head: 4 label: 'pobj' break_level: SPACE_BREAK } "
"token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
" head: 1 label: 'p' break_level: NO_BREAK }") {}
};
TEST_F(CommonSentenceFeaturesTest, TagFeature) {
PrepareFeature("tag");
EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
EXPECT_EQ("PRP", ExtractFeature(0));
EXPECT_EQ("VBD", ExtractFeature(1));
EXPECT_EQ("DT", ExtractFeature(2));
EXPECT_EQ("NN", ExtractFeature(3));
EXPECT_EQ("<OUTSIDE>", ExtractFeature(8));
}
TEST_F(CommonSentenceFeaturesTest, TagFeaturePassesArgs) {
PrepareFeature("tag(min-freq=5)"); // don't load any tags
EXPECT_EQ(ExtractFeature(-1), "<OUTSIDE>");
EXPECT_EQ(ExtractFeature(0), "<UNKNOWN>");
EXPECT_EQ(ExtractFeature(8), "<OUTSIDE>");
// Only 2 features: <UNKNOWN> and <OUTSIDE>.
EXPECT_EQ(2, extractor_->feature_type(0)->GetDomainSize());
}
TEST_F(CommonSentenceFeaturesTest, OffsetPlusTag) {
PrepareFeature("offset(-1).tag(min-freq=2)");
EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
EXPECT_EQ("<OUTSIDE>", ExtractFeature(0));
EXPECT_EQ("<UNKNOWN>", ExtractFeature(1));
EXPECT_EQ("<UNKNOWN>", ExtractFeature(2));
EXPECT_EQ("DT", ExtractFeature(3)); // DT, NN are the only freq tags
EXPECT_EQ("NN", ExtractFeature(4));
EXPECT_EQ("<UNKNOWN>", ExtractFeature(5));
EXPECT_EQ("DT", ExtractFeature(6));
EXPECT_EQ("NN", ExtractFeature(7));
EXPECT_EQ("<UNKNOWN>", ExtractFeature(8));
EXPECT_EQ("<OUTSIDE>", ExtractFeature(9));
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/shared_store.h"
#include <unordered_map>
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace syntaxnet {
SharedStore::SharedObjectMap *SharedStore::shared_object_map_ =
new SharedObjectMap;
mutex SharedStore::shared_object_map_mutex_(tensorflow::LINKER_INITIALIZED);
SharedStore::SharedObjectMap *SharedStore::shared_object_map() {
return shared_object_map_;
}
bool SharedStore::Release(const void *object) {
if (object == nullptr) {
return true;
}
mutex_lock l(shared_object_map_mutex_);
for (SharedObjectMap::iterator it = shared_object_map()->begin();
it != shared_object_map()->end(); ++it) {
if (it->second.object == object) {
// Check the invariant that reference counts are positive. A violation
// likely implies memory corruption.
CHECK_GE(it->second.refcount, 1);
it->second.refcount--;
if (it->second.refcount == 0) {
it->second.delete_callback();
shared_object_map()->erase(it);
}
return true;
}
}
return false;
}
void SharedStore::Clear() {
mutex_lock l(shared_object_map_mutex_);
for (SharedObjectMap::iterator it = shared_object_map()->begin();
it != shared_object_map()->end(); ++it) {
it->second.delete_callback();
}
shared_object_map()->clear();
}
string SharedStoreUtils::CreateDefaultName() { return string(); }
string SharedStoreUtils::ToString(const string &input) {
return ToString(tensorflow::StringPiece(input));
}
string SharedStoreUtils::ToString(const char *input) {
return ToString(tensorflow::StringPiece(input));
}
string SharedStoreUtils::ToString(tensorflow::StringPiece input) {
return tensorflow::strings::StrCat("\"", utils::CEscape(input.ToString()),
"\"");
}
string SharedStoreUtils::ToString(bool input) {
return input ? "true" : "false";
}
string SharedStoreUtils::ToString(float input) {
return tensorflow::strings::Printf("%af", input);
}
string SharedStoreUtils::ToString(double input) {
return tensorflow::strings::Printf("%a", input);
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility for creating read-only objects once and sharing them across threads.
#ifndef $TARGETDIR_SHARED_STORE_H_
#define $TARGETDIR_SHARED_STORE_H_
#include <functional>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
class SharedStore {
public:
// Returns an existing object with type T and name 'name' if it exists, else
// creates one with "new T(args...)". Note: Objects will be indexed under
// their typeid + name, so names only have to be unique within a given type.
template <typename T, typename ...Args>
static const T *Get(const string &name,
Args &&...args); // NOLINT(build/c++11)
// Like Get(), but creates the object with "closure->Run()". If the closure
// returns null, we store a null in the SharedStore, but note that Release()
// cannot be used to remove it. This is because Release() finds the object
// by associative lookup, and there may be more than one null value, so we
// don't know which one to release. If the closure returns a duplicate value
// (one that is pointer-equal to an object already in the SharedStore),
// we disregard it and store null instead -- otherwise associative lookup
// would again fail (and the reference counts would be wrong).
template <typename T>
static const T *ClosureGet(const string &name, std::function<T *()> *closure);
// Like ClosureGet(), but check-fails if ClosureGet() would return null.
template <typename T>
static const T *ClosureGetOrDie(const string &name,
std::function<T *()> *closure);
// Release an object that was acquired by Get(). When its reference count
// hits 0, the object will be deleted. Returns true if the object was found.
// Does nothing and returns true if the object is null.
static bool Release(const void *object);
// Delete all objects in the shared store.
static void Clear();
private:
// A shared object.
struct SharedObject {
void *object;
std::function<void()> delete_callback;
int refcount;
SharedObject(void *o, std::function<void()> d)
: object(o), delete_callback(d), refcount(1) {}
};
// A map from keys to shared objects.
typedef std::unordered_map<string, SharedObject> SharedObjectMap;
// Return the shared object map.
static SharedObjectMap *shared_object_map();
// Return the string to use for indexing an object in the shared store.
template <typename T>
static string GetSharedKey(const string &name);
// Delete an object of type T.
template <typename T>
static void DeleteObject(T *object);
// Add an object to the shared object map. Return the object.
template <typename T>
static T *StoreObject(const string &key, T *object);
// Increment the reference count of an object in the map. Return the object.
template <typename T>
static T *IncrementRefCountOfObject(SharedObjectMap::iterator it);
// Map from keys to shared objects.
static SharedObjectMap *shared_object_map_;
static mutex shared_object_map_mutex_;
TF_DISALLOW_COPY_AND_ASSIGN(SharedStore);
};
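// Example (illustrative; mirrors how the tagger transition system in this
// change uses the store, and assumes TermFrequencyMap's (path, min_frequency,
// max_num_terms) constructor):
//
//   const TermFrequencyMap *map =
//       SharedStore::Get<TermFrequencyMap>("tag-map", path, 0, 0);
//   ...
//   SharedStore::Release(map);  // deleted once its refcount reaches zero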
template <typename T>
string SharedStore::GetSharedKey(const string &name) {
const std::type_index id = std::type_index(typeid(T));
return tensorflow::strings::StrCat(id.name(), "_", name);
}
template <typename T>
void SharedStore::DeleteObject(T *object) {
delete object;
}
template <typename T>
T *SharedStore::StoreObject(const string &key, T *object) {
std::function<void()> delete_cb =
std::bind(SharedStore::DeleteObject<T>, object);
SharedObject so(object, delete_cb);
shared_object_map()->insert(std::make_pair(key, so));
return object;
}
template <typename T>
T *SharedStore::IncrementRefCountOfObject(SharedObjectMap::iterator it) {
it->second.refcount++;
return static_cast<T *>(it->second.object);
}
template <typename T, typename ...Args>
const T *SharedStore::Get(const string &name,
Args &&...args) { // NOLINT(build/c++11)
mutex_lock l(shared_object_map_mutex_);
const string key = GetSharedKey<T>(name);
SharedObjectMap::iterator it = shared_object_map()->find(key);
return (it == shared_object_map()->end()) ?
StoreObject<T>(key, new T(std::forward<Args>(args)...)) :
IncrementRefCountOfObject<T>(it);
}
template <typename T>
const T *SharedStore::ClosureGet(const string &name,
std::function<T *()> *closure) {
mutex_lock l(shared_object_map_mutex_);
const string key = GetSharedKey<T>(name);
SharedObjectMap::iterator it = shared_object_map()->find(key);
if (it == shared_object_map()->end()) {
// Creates a new object by calling the closure.
T *object = (*closure)();
if (object == nullptr) {
LOG(ERROR) << "Closure returned a null pointer";
} else {
for (SharedObjectMap::iterator it = shared_object_map()->begin();
it != shared_object_map()->end(); ++it) {
if (it->second.object == object) {
LOG(ERROR)
<< "Closure returned duplicate pointer: "
<< "keys " << it->first << " and " << key;
// Not a memory leak to discard pointer, since we have another copy.
object = nullptr;
break;
}
}
}
return StoreObject<T>(key, object);
} else {
return IncrementRefCountOfObject<T>(it);
}
}
template <typename T>
const T *SharedStore::ClosureGetOrDie(const string &name,
std::function<T *()> *closure) {
const T *object = ClosureGet<T>(name, closure);
CHECK(object != nullptr);
return object;
}
// A collection of utility functions for working with the shared store.
class SharedStoreUtils {
public:
// Returns a shared object registered using a default name that is created
// from the constructor args.
//
// NB: This function does not guarantee a one-to-one relationship between
// sets of constructor args and names. See warnings on CreateDefaultName().
// It is the caller's responsibility to ensure that the args provided will
// result in unique names.
template <class T, class... Args>
static const T *GetWithDefaultName(Args &&... args) { // NOLINT(build/c++11)
return SharedStore::Get<T>(CreateDefaultName(std::forward<Args>(args)...),
std::forward<Args>(args)...);
}
// Returns a string name representing the args. Implemented via a pair of
// overloaded functions to achieve compile-time recursion.
//
// WARNING: It is possible for instances of different types to have the same
// string representation. For example,
//
// CreateDefaultName(1) == CreateDefaultName(1ULL)
//
template <class First, class... Rest>
static string CreateDefaultName(First &&first,
Rest &&... rest) { // NOLINT(build/c++11)
return tensorflow::strings::StrCat(
ToString<First>(std::forward<First>(first)), ",",
CreateDefaultName(std::forward<Rest>(rest)...));
}
static string CreateDefaultName();
private:
// Returns a string representing the input. The generic implementation uses
// StrCat(), and overloads are provided for selected types.
template <class T>
static string ToString(T input) {
return tensorflow::strings::StrCat(input);
}
static string ToString(const string &input);
static string ToString(const char *input);
static string ToString(tensorflow::StringPiece input);
static string ToString(bool input);
static string ToString(float input);
static string ToString(double input);
TF_DISALLOW_COPY_AND_ASSIGN(SharedStoreUtils);
};
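// Example (illustrative, matching the tagger transition system in this
// change): the default name is derived from the constructor arguments, so two
// call sites passing identical arguments share a single object.
//
//   const TermFrequencyMap *tag_map =
//       SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);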
} // namespace syntaxnet
#endif // $TARGETDIR_SHARED_STORE_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/shared_store.h"
#include <string>
#include <gmock/gmock.h>
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
using ::testing::_;
namespace syntaxnet {
struct NoArgs {
NoArgs() {
LOG(INFO) << "Calling NoArgs()";
}
};
struct OneArg {
string name;
explicit OneArg(const string &n) : name(n) {
LOG(INFO) << "Calling OneArg(" << name << ")";
}
};
struct TwoArgs {
string name;
int age;
TwoArgs(const string &n, int a) : name(n), age(a) {
LOG(INFO) << "Calling TwoArgs(" << name << ", " << age << ")";
}
};
struct Slow {
string lengthy;
Slow() {
LOG(INFO) << "Calling Slow()";
lengthy.assign(50 << 20, 'L'); // 50MB of the letter 'L'
}
};
struct CountCalls {
CountCalls() {
LOG(INFO) << "Calling CountCalls()";
++constructor_calls;
}
~CountCalls() {
LOG(INFO) << "Calling ~CountCalls()";
++destructor_calls;
}
static void Reset() {
constructor_calls = 0;
destructor_calls = 0;
}
static int constructor_calls;
static int destructor_calls;
};
int CountCalls::constructor_calls = 0;
int CountCalls::destructor_calls = 0;
class PointerSet {
public:
PointerSet() { }
void Add(const void *p) {
mutex_lock l(mu_);
pointers_.insert(p);
}
int size() {
mutex_lock l(mu_);
return pointers_.size();
}
private:
mutex mu_;
unordered_set<const void *> pointers_;
};
class SharedStoreTest : public testing::Test {
protected:
~SharedStoreTest() {
// Clear the shared store after each test, otherwise objects created
// in one test may interfere with other tests.
SharedStore::Clear();
}
};
// Verify that we can call constructors with varying numbers and types of args.
TEST_F(SharedStoreTest, ConstructorArgs) {
SharedStore::Get<NoArgs>("no args");
SharedStore::Get<OneArg>("one arg", "Fred");
SharedStore::Get<TwoArgs>("two args", "Pebbles", 2);
}
// Verify that an object with a given key is created only once.
TEST_F(SharedStoreTest, Shared) {
const NoArgs *ob1 = SharedStore::Get<NoArgs>("first");
const NoArgs *ob2 = SharedStore::Get<NoArgs>("second");
const NoArgs *ob3 = SharedStore::Get<NoArgs>("first");
EXPECT_EQ(ob1, ob3);
EXPECT_NE(ob1, ob2);
EXPECT_NE(ob2, ob3);
}
// Verify that objects with the same name but different types do not collide.
TEST_F(SharedStoreTest, DifferentTypes) {
const NoArgs *ob1 = SharedStore::Get<NoArgs>("same");
const OneArg *ob2 = SharedStore::Get<OneArg>("same", "foo");
const TwoArgs *ob3 = SharedStore::Get<TwoArgs>("same", "bar", 5);
EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob2));
EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob3));
EXPECT_NE(static_cast<const void *>(ob2), static_cast<const void *>(ob3));
}
// Factory method to make a OneArg.
OneArg *MakeOneArg(const string &n) {
return new OneArg(n);
}
TEST_F(SharedStoreTest, ClosureGet) {
std::function<OneArg *()> closure1 = std::bind(MakeOneArg, "Al");
std::function<OneArg *()> closure2 = std::bind(MakeOneArg, "Al");
const OneArg *ob1 = SharedStore::ClosureGet("first", &closure1);
const OneArg *ob2 = SharedStore::ClosureGet("first", &closure2);
EXPECT_EQ("Al", ob1->name);
EXPECT_EQ(ob1, ob2);
}
TEST_F(SharedStoreTest, PermanentCallback) {
std::function<OneArg *()> closure = std::bind(MakeOneArg, "Al");
const OneArg *ob1 = SharedStore::ClosureGet("first", &closure);
const OneArg *ob2 = SharedStore::ClosureGet("first", &closure);
EXPECT_EQ("Al", ob1->name);
EXPECT_EQ(ob1, ob2);
}
// Factory method to "make" a NoArgs by simply returning an input pointer.
NoArgs *BogusMakeNoArgs(NoArgs *ob) {
return ob;
}
// Create a CountCalls object, pretend it failed, and return null.
CountCalls *MakeFailedCountCalls() {
CountCalls *ob = new CountCalls;
delete ob;
return nullptr;
}
// Verify that ClosureGet() only calls the closure for a given key once,
// even if the closure fails.
TEST_F(SharedStoreTest, FailedClosureGet) {
CountCalls::Reset();
std::function<CountCalls *()> closure1(MakeFailedCountCalls);
std::function<CountCalls *()> closure2(MakeFailedCountCalls);
const CountCalls *ob1 = SharedStore::ClosureGet("first", &closure1);
const CountCalls *ob2 = SharedStore::ClosureGet("first", &closure2);
EXPECT_EQ(nullptr, ob1);
EXPECT_EQ(nullptr, ob2);
EXPECT_EQ(1, CountCalls::constructor_calls);
}
typedef SharedStoreTest SharedStoreDeathTest;
TEST_F(SharedStoreDeathTest, ClosureGetOrDie) {
NoArgs *empty = nullptr;
std::function<NoArgs *()> closure = std::bind(BogusMakeNoArgs, empty);
EXPECT_DEATH(SharedStore::ClosureGetOrDie("first", &closure), "nullptr");
}
TEST_F(SharedStoreTest, Release) {
const OneArg *ob1 = SharedStore::Get<OneArg>("first", "Fred");
const OneArg *ob2 = SharedStore::Get<OneArg>("first", "Fred");
EXPECT_EQ(ob1, ob2);
EXPECT_TRUE(SharedStore::Release(ob1)); // now refcount = 1
EXPECT_TRUE(SharedStore::Release(ob1)); // now object is deleted
EXPECT_FALSE(SharedStore::Release(ob1)); // now object is not found
EXPECT_TRUE(SharedStore::Release(nullptr)); // release(nullptr) returns true
}
TEST_F(SharedStoreTest, Clear) {
CountCalls::Reset();
SharedStore::Get<CountCalls>("first");
SharedStore::Get<CountCalls>("second");
SharedStore::Get<CountCalls>("first");
// Test that the constructor and destructor are each called exactly once
// for each key in the shared store.
SharedStore::Clear();
EXPECT_EQ(2, CountCalls::constructor_calls);
EXPECT_EQ(2, CountCalls::destructor_calls);
}
void GetSharedObject(PointerSet *ps) {
// Gets a shared object whose constructor takes a long time.
const Slow *ob = SharedStore::Get<Slow>("first");
// Collects the pointer we got. Later, we'll check whether SharedStore
// mistakenly called the constructor more than once.
ps->Add(static_cast<const void *>(ob));
}
// If multiple parallel threads all access an object with the same key,
// only one object is created.
TEST_F(SharedStoreTest, ThreadSafety) {
const int kNumThreads = 20;
tensorflow::thread::ThreadPool *pool = new tensorflow::thread::ThreadPool(
tensorflow::Env::Default(), "ThreadSafetyPool", kNumThreads);
PointerSet ps;
for (int i = 0; i < kNumThreads; ++i) {
std::function<void()> closure = std::bind(GetSharedObject, &ps);
pool->Schedule(closure);
}
// Deleting the pool waits for all scheduled closures to finish.
delete pool;
// Expects only one object to have been created across all threads.
EXPECT_EQ(1, ps.size());
}
} // namespace syntaxnet
// Protocol for passing around sparse sets of features.
syntax = "proto2";
package syntaxnet;
// A sparse set of features.
//
// If using SparseStringToIdTransformer, description is required and id should
// be omitted; otherwise, id is required and description optional.
//
// The id, weight, and description fields are aligned if present (i.e., any of
// these that are non-empty should have the same number of items). If weight
// is omitted, 1.0 is used.
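//
// Example (illustrative values, text format): two features where all three
// repeated fields are populated and therefore have the same number of items.
//
//   id: 3 id: 17
//   weight: 1.0 weight: 0.5
//   description: "word=saw" description: "tag=VBD"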
message SparseFeatures {
repeated uint64 id = 1;
repeated float weight = 2;
repeated string description = 3;
};
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build structured parser models."""
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from syntaxnet import graph_builder
from syntaxnet.ops import gen_parser_ops
tf.NoGradient('BeamParseReader')
tf.NoGradient('BeamParser')
tf.NoGradient('BeamParserOutput')
def AddCrossEntropy(batch_size, n):
"""Adds a cross entropy cost function."""
cross_entropies = []
def _Pass():
return tf.constant(0, dtype=tf.float32, shape=[1])
for beam_id in range(batch_size):
beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1])
def _ComputeCrossEntropy():
"""Adds ops to compute cross entropy of the gold path in a beam."""
# Requires a cast so that UnsortedSegmentSum, in the gradient,
# is happy with the type of its input 'segment_ids', which
# must be int32.
idx = tf.cast(
tf.reshape(
tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
num = tf.shape(idx)
return tf.nn.softmax_cross_entropy_with_logits(
beam_scores, tf.expand_dims(
tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0))
# The conditional here is needed to deal with the last few batches of the
# corpus which can contain -1 in beam_gold_slot for empty batch slots.
cross_entropies.append(cf.cond(
beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}
class StructuredGraphBuilder(graph_builder.GreedyParser):
"""Extends the standard GreedyParser with a CRF objective using a beam.
The constructor takes two additional keyword arguments.
beam_size: the maximum size the beam can grow to.
max_steps: the maximum number of steps in any particular beam.
The model supports batch training with the batch_size argument to the
AddTraining method.
"""
def __init__(self, *args, **kwargs):
self._beam_size = kwargs.pop('beam_size', 10)
self._max_steps = kwargs.pop('max_steps', 25)
super(StructuredGraphBuilder, self).__init__(*args, **kwargs)
def _AddBeamReader(self,
task_context,
batch_size,
corpus_name,
until_all_final=False,
always_start_new_sentences=False):
"""Adds an op capable of reading sentences and parsing them with a beam."""
features, state, epochs = gen_parser_ops.beam_parse_reader(
task_context=task_context,
feature_size=self._feature_size,
beam_size=self._beam_size,
batch_size=batch_size,
corpus_name=corpus_name,
allow_feature_weights=self._allow_feature_weights,
arg_prefix=self._arg_prefix,
continue_until_all_final=until_all_final,
always_start_new_sentences=always_start_new_sentences)
return {'state': state, 'features': features, 'epochs': epochs}
def _BuildSequence(self,
batch_size,
max_steps,
features,
state,
use_average=False):
"""Adds a sequence of beam parsing steps."""
def Advance(state, step, scores_array, alive, alive_steps, *features):
scores = self._BuildNetwork(features,
return_average=use_average)['logits']
scores_array = scores_array.write(step, scores)
features, state, alive = (
gen_parser_ops.beam_parser(state, scores, self._feature_size))
return [state, step + 1, scores_array, alive, alive_steps + tf.cast(
alive, tf.int32)] + list(features)
# args: (state, step, scores_array, alive, alive_steps, *features)
def KeepGoing(*args):
return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))
step = tf.constant(0, tf.int32, [])
scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
size=0,
dynamic_size=True)
alive = tf.constant(True, tf.bool, [batch_size])
alive_steps = tf.constant(0, tf.int32, [batch_size])
t = tf.while_loop(
KeepGoing,
Advance,
[state, step, scores_array, alive, alive_steps] + list(features),
parallel_iterations=100)
# Link to the final nodes/values of ops that have passed through While:
return {'state': t[0],
'concat_scores': t[2].concat(),
'alive': t[3],
'alive_steps': t[4]}
def AddTraining(self,
task_context,
batch_size,
learning_rate=0.1,
decay_steps=4000,
momentum=None,
corpus_name='documents'):
with tf.name_scope('training'):
n = self.training
n['accumulated_alive_steps'] = self._AddVariable(
[batch_size], tf.int32, 'accumulated_alive_steps',
tf.zeros_initializer)
n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
# This adds a required 'step' node too:
learning_rate = tf.constant(learning_rate, dtype=tf.float32)
n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)
# Call BuildNetwork *only* to set up the params outside of the main loop.
self._BuildNetwork(list(n['features']))
n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
n['state']))
flat_concat_scores = tf.reshape(n['concat_scores'], [-1])
(indices_and_paths, beams_and_slots, n['gold_slot'], n[
'beam_path_scores']) = gen_parser_ops.beam_parser_output(n[
'state'])
n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
n['all_path_scores'] = tf.sparse_segment_sum(
flat_concat_scores, n['indices'], n['path_ids'])
n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])
n.update(AddCrossEntropy(batch_size, n))
if self._only_train:
trainable_params = {k: v for k, v in self.params.iteritems()
if k in self._only_train}
else:
trainable_params = self.params
for p in trainable_params:
tf.logging.info('trainable_param: %s', p)
regularized_params = [
tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
if k.startswith('weights') or k.startswith('bias')]
l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0
n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')
n['gradients'] = tf.gradients(n['cost'], trainable_params.values())
with tf.control_dependencies([n['alive_steps']]):
update_accumulators = tf.group(
tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))
def ResetAccumulators():
return tf.assign(
n['accumulated_alive_steps'], tf.zeros([batch_size], tf.int32))
n['reset_accumulators_func'] = ResetAccumulators
optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
momentum,
use_locking=self._use_locking)
train_op = optimizer.minimize(n['cost'],
var_list=trainable_params.values())
for param in trainable_params.values():
slot = optimizer.get_slot(param, 'momentum')
self.inits[slot.name] = state_ops.init_variable(slot,
tf.zeros_initializer)
self.variables[slot.name] = slot
def NumericalChecks():
return tf.group(*[
tf.check_numerics(param, message='Parameter is not finite.')
for param in trainable_params.values()
if param.dtype.base_dtype in [tf.float32, tf.float64]])
check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every), 0),
NumericalChecks, tf.no_op)
avg_update_op = tf.group(*self._averaging.values())
train_ops = [train_op]
if self._check_parameters:
train_ops.append(check_op)
if self._use_averaging:
train_ops.append(avg_update_op)
with tf.control_dependencies([update_accumulators]):
n['train_op'] = tf.group(*train_ops, name='train_op')
n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
return n
def AddEvaluation(self,
task_context,
batch_size,
evaluation_max_steps=300,
corpus_name=None):
with tf.name_scope('evaluation'):
n = self.evaluation
n.update(self._AddBeamReader(task_context,
batch_size,
corpus_name,
until_all_final=True,
always_start_new_sentences=True))
self._BuildNetwork(
list(n['features']),
return_average=self._use_averaging)
n.update(self._BuildSequence(batch_size, evaluation_max_steps, n[
'features'], n['state'], use_average=self._use_averaging))
n['eval_metrics'], n['documents'] = (
gen_parser_ops.beam_eval_output(n['state']))
return n
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
load("@tf//google/protobuf:protobuf.bzl", "cc_proto_library")
load("@tf//google/protobuf:protobuf.bzl", "py_proto_library")
def if_cuda(a, b=[]):
return select({
"@tf//third_party/gpus/cuda:cuda_crosstool_condition": a,
"//conditions:default": b,
})
def tf_copts():
return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
if_cuda(["-DGOOGLE_CUDA=1"]) +
select({"@tf//tensorflow:darwin": [],
"//conditions:default": ["-pthread"]}))
def tf_proto_library(name, srcs=[], has_services=False,
deps=[], visibility=None, testonly=0,
cc_api_version=2, go_api_version=2,
java_api_version=2,
py_api_version=2):
native.filegroup(name=name + "_proto_srcs",
srcs=srcs,
testonly=testonly,)
cc_proto_library(name=name,
srcs=srcs,
deps=deps,
cc_libs = ["@tf//google/protobuf:protobuf"],
protoc="@tf//google/protobuf:protoc",
default_runtime="@tf//google/protobuf:protobuf",
testonly=testonly,
visibility=visibility,)
def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
py_proto_library(name=name,
srcs=srcs,
srcs_version = "PY2AND3",
deps=deps,
default_runtime="@tf//google/protobuf:protobuf_python",
protoc="@tf//google/protobuf:protoc",
visibility=visibility,
testonly=testonly,)
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for each of those files.
def tf_gen_op_libs(op_lib_names):
# Make library out of each op so it can also be used to generate wrappers
# for various languages.
for n in op_lib_names:
native.cc_library(name=n + "_op_lib",
copts=tf_copts(),
srcs=["ops/" + n + ".cc"],
deps=(["@tf//tensorflow/core:framework"]),
visibility=["//visibility:public"],
alwayslink=1,
linkstatic=1,)
# Invoke this rule in .../tensorflow/python to build the wrapper library.
def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
require_shape_functions=False):
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
if not deps:
deps = ["//tensorflow/core:" + name + "_op_lib"]
native.cc_binary(
name = tool_name,
linkopts = ["-lm"],
copts = tf_copts(),
linkstatic = 1, # Faster to link this one-time-use binary statically
deps = (["@tf//tensorflow/core:framework",
"@tf//tensorflow/python:python_op_gen_main"] + deps),
)
# Invoke the previous cc_binary to generate a python file.
if not out:
out = "ops/gen_" + name + ".py"
native.genrule(
name=name + "_pygenrule",
outs=[out],
tools=[tool_name],
cmd=("$(location " + tool_name + ") " + ",".join(hidden)
+ " " + ("1" if require_shape_functions else "0") + " > $@"))
# Make a py_library out of the generated python file.
native.py_library(name=name,
srcs=[out],
srcs_version="PY2AND3",
visibility=visibility,
deps=[
"@tf//tensorflow/python:framework_for_generated_wrappers",
],)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tagger transition system.
//
// This transition system has one type of action:
// - The SHIFT action pushes the next input token to the stack and
// advances to the next input token, assigning a part-of-speech tag to the
// token that was shifted.
//
// The transition system operates with parser actions encoded as integers:
// - A SHIFT action is encoded as a number starting from 0; its value is the
//   index of the tag to assign.
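//
// For example (with an illustrative tag map {NN: 0, VB: 1, DT: 2}): there are
// three possible actions, and action 1 means "shift the next input token and
// tag it VB". The gold action for a state is the gold tag index of the next
// input token.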
#include <string>
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
class TaggerTransitionState : public ParserTransitionState {
public:
explicit TaggerTransitionState(const TermFrequencyMap *tag_map,
const TagToCategoryMap *tag_to_category)
: tag_map_(tag_map), tag_to_category_(tag_to_category) {}
explicit TaggerTransitionState(const TaggerTransitionState *state)
: TaggerTransitionState(state->tag_map_, state->tag_to_category_) {
tag_ = state->tag_;
gold_tag_ = state->gold_tag_;
}
// Clones the transition state by returning a new object.
ParserTransitionState *Clone() const override {
return new TaggerTransitionState(this);
}
// Reads gold tags for each token.
void Init(ParserState *state) override {
tag_.resize(state->sentence().token_size(), -1);
gold_tag_.resize(state->sentence().token_size(), -1);
for (int pos = 0; pos < state->sentence().token_size(); ++pos) {
int tag = tag_map_->LookupIndex(state->GetToken(pos).tag(), -1);
gold_tag_[pos] = tag;
}
}
// Returns the tag assigned to a given token.
int Tag(int index) const {
DCHECK_GE(index, 0);
DCHECK_LT(index, tag_.size());
return index == -1 ? -1 : tag_[index];
}
// Sets this tag on the token at index.
void SetTag(int index, int tag) {
DCHECK_GE(index, 0);
DCHECK_LT(index, tag_.size());
tag_[index] = tag;
}
// Returns the gold tag for a given token.
int GoldTag(int index) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, gold_tag_.size());
return index == -1 ? -1 : gold_tag_[index];
}
// Returns the string representation of a POS tag, or an empty string
// if the tag is invalid.
string TagAsString(int tag) const {
if (tag >= 0 && tag < tag_map_->Size()) {
return tag_map_->GetTerm(tag);
}
return "";
}
// Adds transition state specific annotations to the document.
void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
Sentence *sentence) const override {
for (size_t i = 0; i < tag_.size(); ++i) {
Token *token = sentence->mutable_token(i);
token->set_tag(TagAsString(Tag(i)));
token->set_category(tag_to_category_->GetCategory(token->tag()));
}
}
// Whether a parsed token should be considered correct for evaluation.
bool IsTokenCorrect(const ParserState &state, int index) const override {
return GoldTag(index) == Tag(index);
}
// Returns a human readable string representation of this state.
string ToString(const ParserState &state) const override {
string str;
for (int i = state.StackSize(); i > 0; --i) {
const string &word = state.GetToken(state.Stack(i - 1)).word();
if (i != state.StackSize() - 1) str.append(" ");
tensorflow::strings::StrAppend(
&str, word, "[", TagAsString(Tag(state.StackSize() - i)), "]");
}
for (int i = state.Next(); i < state.NumTokens(); ++i) {
tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
}
return str;
}
private:
// Currently assigned POS tags for each token in this sentence.
vector<int> tag_;
// Gold POS tags from the input document.
vector<int> gold_tag_;
// Tag map used for conversions between integer and string representations
// of part-of-speech tags. Not owned.
const TermFrequencyMap *tag_map_ = nullptr;
// Tag to category map. Not owned.
const TagToCategoryMap *tag_to_category_ = nullptr;
TF_DISALLOW_COPY_AND_ASSIGN(TaggerTransitionState);
};
class TaggerTransitionSystem : public ParserTransitionSystem {
public:
~TaggerTransitionSystem() override { SharedStore::Release(tag_map_); }
// Determines tag map location.
void Setup(TaskContext *context) override {
input_tag_map_ = context->GetInput("tag-map", "text", "");
input_tag_to_category_ = context->GetInput("tag-to-category", "text", "");
}
// Reads tag map and tag to category map.
void Init(TaskContext *context) override {
const string tag_map_path = TaskContext::InputFile(*input_tag_map_);
tag_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
tag_map_path, 0, 0);
const string tag_to_category_path =
TaskContext::InputFile(*input_tag_to_category_);
tag_to_category_ = SharedStoreUtils::GetWithDefaultName<TagToCategoryMap>(
tag_to_category_path);
}
// A SHIFT action is encoded as the index of the tag it assigns.
static ParserAction ShiftAction(int tag) { return tag; }
// Returns the number of action types.
int NumActionTypes() const override { return 1; }
// Returns the number of possible actions.
int NumActions(int num_labels) const override { return tag_map_->Size(); }
// The default action for a given state is assigning the most frequent tag.
ParserAction GetDefaultAction(const ParserState &state) const override {
return ShiftAction(0);
}
// Returns the next gold action for a given state according to the
// underlying annotated sentence.
ParserAction GetNextGoldAction(const ParserState &state) const override {
if (!state.EndOfInput()) {
return ShiftAction(TransitionState(state).GoldTag(state.Next()));
}
return ShiftAction(0);
}
// Checks if the action is allowed in a given parser state.
bool IsAllowedAction(ParserAction action,
const ParserState &state) const override {
return !state.EndOfInput();
}
// Makes a shift by pushing the next input token on the stack and moving to
// the next position.
void PerformActionWithoutHistory(ParserAction action,
ParserState *state) const override {
DCHECK(!state->EndOfInput());
if (!state->EndOfInput()) {
MutableTransitionState(state)->SetTag(state->Next(), action);
state->Push(state->Next());
state->Advance();
}
}
// We are in a final state once we have reached the end of the input.
bool IsFinalState(const ParserState &state) const override {
return state.EndOfInput();
}
// Returns a string representation of a parser action.
string ActionAsString(ParserAction action,
const ParserState &state) const override {
return tensorflow::strings::StrCat("SHIFT(", tag_map_->GetTerm(action),
")");
}
// No state is deterministic in this transition system.
bool IsDeterministicState(const ParserState &state) const override {
return false;
}
// Returns a new transition state to be used to enhance the parser state.
ParserTransitionState *NewTransitionState(bool training_mode) const override {
return new TaggerTransitionState(tag_map_, tag_to_category_);
}
// Downcasts the const ParserTransitionState in ParserState to a const
// TaggerTransitionState.
static const TaggerTransitionState &TransitionState(
const ParserState &state) {
return *static_cast<const TaggerTransitionState *>(
state.transition_state());
}
// Downcasts the ParserTransitionState in ParserState to a
// TaggerTransitionState.
static TaggerTransitionState *MutableTransitionState(ParserState *state) {
return static_cast<TaggerTransitionState *>(
state->mutable_transition_state());
}
// Input for the tag map. Not owned.
TaskInput *input_tag_map_ = nullptr;
// Tag map used for conversions between integer and string representations
// of part-of-speech tags. Owned through SharedStore.
const TermFrequencyMap *tag_map_ = nullptr;
// Input for the tag to category map. Not owned.
TaskInput *input_tag_to_category_ = nullptr;
// Tag to category map. Owned through SharedStore.
const TagToCategoryMap *tag_to_category_ = nullptr;
};
REGISTER_TRANSITION_SYSTEM("tagger", TaggerTransitionSystem);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include "syntaxnet/utils.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
class TaggerTransitionTest : public ::testing::Test {
public:
TaggerTransitionTest()
: transition_system_(ParserTransitionSystem::Create("tagger")) {}
protected:
// Creates a label map and a tag map for testing based on the given
// document and initializes the transition system appropriately.
void SetUpForDocument(const Sentence &document) {
input_label_map_ = context_.GetInput("label-map", "text", "");
input_label_map_ = context_.GetInput("tag-map", "text", "");
transition_system_->Setup(&context_);
PopulateTestInputs::Defaults(document).Populate(&context_);
label_map_.Load(TaskContext::InputFile(*input_label_map_),
0 /* minimum frequency */,
-1 /* maximum number of terms */);
transition_system_->Init(&context_);
}
// Creates a cloned state from a sentence in order to test that cloning
// works correctly for the new parser states.
ParserState *NewClonedState(Sentence *sentence) {
ParserState state(sentence, transition_system_->NewTransitionState(
true /* training mode */),
&label_map_);
return state.Clone();
}
// Performs the gold sequence of transitions, checking that each gold action
// is allowed in the corresponding parser state.
void GoldParse(Sentence *sentence) {
ParserState *state = NewClonedState(sentence);
LOG(INFO) << "Initial parser state: " << state->ToString();
while (!transition_system_->IsFinalState(*state)) {
ParserAction action = transition_system_->GetNextGoldAction(*state);
EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
LOG(INFO) << "Performing action: "
<< transition_system_->ActionAsString(action, *state);
transition_system_->PerformActionWithoutHistory(action, state);
LOG(INFO) << "Parser state: " << state->ToString();
}
delete state;
}
// Always takes the default action, and verifies that this leads to
// a final state through a sequence of allowed actions.
void DefaultParse(Sentence *sentence) {
ParserState *state = NewClonedState(sentence);
LOG(INFO) << "Initial parser state: " << state->ToString();
while (!transition_system_->IsFinalState(*state)) {
ParserAction action = transition_system_->GetDefaultAction(*state);
EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
LOG(INFO) << "Performing action: "
<< transition_system_->ActionAsString(action, *state);
transition_system_->PerformActionWithoutHistory(action, state);
LOG(INFO) << "Parser state: " << state->ToString();
}
delete state;
}
TaskContext context_;
TaskInput *input_label_map_ = nullptr;
TermFrequencyMap label_map_;
std::unique_ptr<ParserTransitionSystem> transition_system_;
};
TEST_F(TaggerTransitionTest, SingleSentenceDocumentTest) {
string document_text;
Sentence document;
TF_CHECK_OK(ReadFileToString(
tensorflow::Env::Default(),
"syntaxnet/testdata/document",
&document_text));
LOG(INFO) << "see doc\n:" << document_text;
CHECK(TextFormat::ParseFromString(document_text, &document));
SetUpForDocument(document);
GoldParse(&document);
DefaultParse(&document);
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/task_context.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {
namespace {
const char *const kShardPrintFormat = "%05d";
} // namespace
TaskInput *TaskContext::GetInput(const string &name) {
// Return the existing input if one with this name already exists.
for (int i = 0; i < spec_.input_size(); ++i) {
if (spec_.input(i).name() == name) return spec_.mutable_input(i);
}
// Create new input.
TaskInput *input = spec_.add_input();
input->set_name(name);
return input;
}
TaskInput *TaskContext::GetInput(const string &name, const string &file_format,
const string &record_format) {
TaskInput *input = GetInput(name);
if (!file_format.empty()) {
bool found = false;
for (int i = 0; i < input->file_format_size(); ++i) {
if (input->file_format(i) == file_format) found = true;
}
if (!found) input->add_file_format(file_format);
}
if (!record_format.empty()) {
bool found = false;
for (int i = 0; i < input->record_format_size(); ++i) {
if (input->record_format(i) == record_format) found = true;
}
if (!found) input->add_record_format(record_format);
}
return input;
}
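// Example (as used by the tagger transition system in this change): request
// an input and record the file format it must provide.
//
//   TaskInput *tag_map_input = context->GetInput("tag-map", "text", "");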
void TaskContext::SetParameter(const string &name, const string &value) {
// If the parameter already exists update the value.
for (int i = 0; i < spec_.parameter_size(); ++i) {
if (spec_.parameter(i).name() == name) {
spec_.mutable_parameter(i)->set_value(value);
return;
}
}
// Add new parameter.
TaskSpec::Parameter *param = spec_.add_parameter();
param->set_name(name);
param->set_value(value);
}
string TaskContext::GetParameter(const string &name) const {
// First try to find parameter in task specification.
for (int i = 0; i < spec_.parameter_size(); ++i) {
if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
}
// Parameter not found, return empty string.
return "";
}
int TaskContext::GetIntParameter(const string &name) const {
string value = GetParameter(name);
return utils::ParseUsing<int>(value, 0, utils::ParseInt32);
}
int64 TaskContext::GetInt64Parameter(const string &name) const {
string value = GetParameter(name);
return utils::ParseUsing<int64>(value, 0ll, utils::ParseInt64);
}
bool TaskContext::GetBoolParameter(const string &name) const {
string value = GetParameter(name);
return value == "true";
}
double TaskContext::GetFloatParameter(const string &name) const {
string value = GetParameter(name);
return utils::ParseUsing<double>(value, .0, utils::ParseDouble);
}
string TaskContext::Get(const string &name, const char *defval) const {
// First try to find parameter in task specification.
for (int i = 0; i < spec_.parameter_size(); ++i) {
if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
}
// Parameter not found, return default value.
return defval;
}
string TaskContext::Get(const string &name, const string &defval) const {
return Get(name, defval.c_str());
}
int TaskContext::Get(const string &name, int defval) const {
string value = Get(name, "");
return utils::ParseUsing<int>(value, defval, utils::ParseInt32);
}
int64 TaskContext::Get(const string &name, int64 defval) const {
string value = Get(name, "");
return utils::ParseUsing<int64>(value, defval, utils::ParseInt64);
}
double TaskContext::Get(const string &name, double defval) const {
string value = Get(name, "");
return utils::ParseUsing<double>(value, defval, utils::ParseDouble);
}
bool TaskContext::Get(const string &name, bool defval) const {
string value = Get(name, "");
return value.empty() ? defval : value == "true";
}
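// Example (illustrative; the parameter names are hypothetical): callers read
// typed parameters with a default that is returned when the parameter is not
// set in the task spec.
//
//   int beam_size = context->Get("beam_size", 8);
//   bool use_averaging = context->Get("use_averaging", false);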
string TaskContext::InputFile(const TaskInput &input) {
CHECK_EQ(input.part_size(), 1) << input.name();
return input.part(0).file_pattern();
}
bool TaskContext::Supports(const TaskInput &input, const string &file_format,
const string &record_format) {
// Check file format.
if (input.file_format_size() > 0) {
bool found = false;
for (int i = 0; i < input.file_format_size(); ++i) {
if (input.file_format(i) == file_format) {
found = true;
break;
}
}
if (!found) return false;
}
// Check record format.
if (input.record_format_size() > 0) {
bool found = false;
for (int i = 0; i < input.record_format_size(); ++i) {
if (input.record_format(i) == record_format) {
found = true;
break;
}
}
if (!found) return false;
}
return true;
}
} // namespace syntaxnet