Commit d66941ac authored by Ivan Bogatyy, committed by calberti

Sync w TF r0.12 & Bazel 0.4.3, internal updates (#953)

parent efa4a6cf
@@ -117,7 +117,7 @@ class LexiconBuilder : public OpKernel {
         tag_to_category.SetCategory(token.tag(), token.category());
 
         // Add characters.
-        vector<tensorflow::StringPiece> char_sp;
+        std::vector<tensorflow::StringPiece> char_sp;
         SegmenterUtils::GetUTF8Chars(word, &char_sp);
         for (const auto &c : char_sp) {
           const string c_str = c.ToString();
......
@@ -132,10 +132,10 @@ class MorphologyTransitionState : public ParserTransitionState {
  private:
   // Currently assigned morphological analysis for each token in this sentence.
-  vector<int> tag_;
+  std::vector<int> tag_;
 
   // Gold morphological analysis from the input document.
-  vector<int> gold_tag_;
+  std::vector<int> gold_tag_;
 
   // Tag map used for conversions between integer and string representations
   // part of speech tags. Not owned.
......
@@ -69,7 +69,7 @@ void MorphologyLabelSet::Write(ProtoRecordWriter *writer) const {
 }
 
 string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const {
-  vector<string> attributes;
+  std::vector<string> attributes;
   for (const auto &a : morph.attribute()) {
     attributes.push_back(
         tensorflow::strings::StrCat(a.name(), kSeparator, a.value()));
@@ -80,7 +80,7 @@ string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const {
 string FullLabelFeatureType::GetFeatureValueName(FeatureValue value) const {
   const TokenMorphology &morph = label_set_->Lookup(value);
-  vector<string> attributes;
+  std::vector<string> attributes;
   for (const auto &a : morph.attribute()) {
     attributes.push_back(tensorflow::strings::StrCat(a.name(), ":", a.value()));
   }
......
@@ -70,7 +70,7 @@ class MorphologyLabelSet {
   // defined as follows:
   //
   // a == b iff the set of attribute pairs (attribute, value) is identical.
-  vector<TokenMorphology> label_set_;
+  std::vector<TokenMorphology> label_set_;
 
   // Because protocol buffer equality is complicated, we implement our own
   // equality operator based on strings. This unordered_map allows us to do the
......
@@ -24,11 +24,12 @@ limitations under the License.
 namespace syntaxnet {
 
 // Registry for the parser feature functions.
-REGISTER_CLASS_REGISTRY("parser feature function", ParserFeatureFunction);
+REGISTER_SYNTAXNET_CLASS_REGISTRY("parser feature function",
+                                  ParserFeatureFunction);
 
 // Registry for the parser state + token index feature functions.
-REGISTER_CLASS_REGISTRY("parser+index feature function",
-                        ParserIndexFeatureFunction);
+REGISTER_SYNTAXNET_CLASS_REGISTRY("parser+index feature function",
+                                  ParserIndexFeatureFunction);
 
 RootFeatureType::RootFeatureType(const string &name,
                                  const FeatureType &wrapped_type,
@@ -228,4 +229,29 @@ class ParserTokenFeatureFunction : public NestedFeatureFunction<
 REGISTER_PARSER_IDX_FEATURE_FUNCTION("token",
                                      ParserTokenFeatureFunction);
 
+// Parser feature that always fetches the focus (position) of the token.
+class FocusFeatureFunction : public ParserIndexFeatureFunction {
+ public:
+  // Initializes the feature function.
+  void Init(TaskContext *context) override {
+    // Note: this feature can return up to N values, where N is the length of
+    // the input sentence. Here, we give the arbitrary number 100 since it
+    // is not used.
+    set_feature_type(new NumericFeatureType(name(), 100));
+  }
+
+  void Evaluate(const WorkspaceSet &workspaces, const ParserState &object,
+                int focus, FeatureVector *result) const override {
+    FeatureValue value = focus;
+    result->add(feature_type(), value);
+  }
+
+  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
+                       int focus, const FeatureVector *result) const override {
+    return focus;
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("focus", FocusFeatureFunction);
+
 }  // namespace syntaxnet
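As a usage sketch of the renamed registration macro, a hypothetical feature would now be declared and registered as below, mirroring FocusFeatureFunction above. ClippedFocusFeatureFunction and the "clipped-focus" name are illustrative and are not part of this commit.

// Sketch only: a parser feature that returns the focus clipped to a fixed
// domain, following the FocusFeatureFunction pattern shown in this diff.
class ClippedFocusFeatureFunction : public ParserIndexFeatureFunction {
 public:
  void Init(TaskContext *context) override {
    // Declare an arbitrary domain of 100 values, as FocusFeatureFunction does.
    set_feature_type(new NumericFeatureType(name(), 100));
  }

  void Evaluate(const WorkspaceSet &workspaces, const ParserState &state,
                int focus, FeatureVector *result) const override {
    result->add(feature_type(), Compute(workspaces, state, focus, result));
  }

  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
                       int focus, const FeatureVector *result) const override {
    // Clip the focus into the declared domain [0, 100).
    return focus < 0 ? 0 : (focus < 100 ? focus : 99);
  }
};

// Registration goes through the SYNTAXNET-prefixed macro introduced here.
REGISTER_PARSER_IDX_FEATURE_FUNCTION("clipped-focus", ClippedFocusFeatureFunction);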
@@ -48,10 +48,11 @@ typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction;
 // Utilities to register the two types of parser features.
 #define REGISTER_PARSER_FEATURE_FUNCTION(name, component) \
-  REGISTER_FEATURE_FUNCTION(ParserFeatureFunction, name, component)
+  REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserFeatureFunction, name, component)
 
-#define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component) \
-  REGISTER_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, component)
+#define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component)           \
+  REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, \
+                                      component)
 
 // Alias for locator type that takes a parser state, and produces a focus
 // integer that can be used on nested ParserIndexFeature objects.
......
@@ -210,14 +210,14 @@ class ParserState {
   int next_;
 
   // Parse stack of partially processed tokens.
-  vector<int> stack_;
+  std::vector<int> stack_;
 
   // List of head positions for the (partial) dependency tree.
-  vector<int> head_;
+  std::vector<int> head_;
 
   // List of dependency relation labels describing the (partial) dependency
   // tree.
-  vector<int> label_;
+  std::vector<int> label_;
 
   // Score of the parser state.
   double score_ = 0.0;
......
@@ -17,6 +17,9 @@
 # This test trains a parser on a small dataset, then runs it in greedy mode and
 # in structured mode with beam 1, and checks that the result is identical.
 set -eux
 
+BINDIR=$TEST_SRCDIR/$TEST_WORKSPACE/syntaxnet
......
@@ -20,7 +20,7 @@ limitations under the License.
 namespace syntaxnet {
 
 // Transition system registry.
-REGISTER_CLASS_REGISTRY("transition system", ParserTransitionSystem);
+REGISTER_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
 
 void ParserTransitionSystem::PerformAction(ParserAction action,
                                            ParserState *state) const {
......
@@ -118,7 +118,7 @@ class ParserTransitionSystem
   // Returns all next gold actions for the parser during training using the
   // dependency relations found in the underlying annotated sentence.
   virtual void GetAllNextGoldActions(const ParserState &state,
-                                     vector<ParserAction> *actions) const {
+                                     std::vector<ParserAction> *actions) const {
     ParserAction action = GetNextGoldAction(state);
     *actions = {action};
   }
@@ -201,7 +201,7 @@ class ParserTransitionSystem
 };
 
 #define REGISTER_TRANSITION_SYSTEM(type, component) \
-  REGISTER_CLASS_COMPONENT(ParserTransitionSystem, type, component)
+  REGISTER_SYNTAXNET_CLASS_COMPONENT(ParserTransitionSystem, type, component)
 
 }  // namespace syntaxnet
......
@@ -88,13 +88,13 @@ bool PopulateTestInputs::Populate(
 PopulateTestInputs::Create PopulateTestInputs::CreateTFMapFromDocumentTokens(
     const Sentence &document,
-    std::function<vector<string>(const Token &)> token2str) {
+    std::function<std::vector<string>(const Token &)> token2str) {
   return [document, token2str](TaskInput *input) {
     TermFrequencyMap map;
 
     // Build and write the dummy term frequency map.
     for (const Token &token : document.token()) {
-      vector<string> strings_for_token = token2str(token);
+      std::vector<string> strings_for_token = token2str(token);
       for (const string &s : strings_for_token) map.Increment(s);
     }
     string file_name = AddPart(input, "text", "");
@@ -116,22 +116,22 @@ PopulateTestInputs::Create PopulateTestInputs::CreateTagToCategoryFromTokens(
   };
 }
 
-vector<string> PopulateTestInputs::TokenCategory(const Token &token) {
+std::vector<string> PopulateTestInputs::TokenCategory(const Token &token) {
   if (token.has_category()) return {token.category()};
   return {};
 }
 
-vector<string> PopulateTestInputs::TokenLabel(const Token &token) {
+std::vector<string> PopulateTestInputs::TokenLabel(const Token &token) {
   if (token.has_label()) return {token.label()};
   return {};
 }
 
-vector<string> PopulateTestInputs::TokenTag(const Token &token) {
+std::vector<string> PopulateTestInputs::TokenTag(const Token &token) {
   if (token.has_tag()) return {token.tag()};
   return {};
 }
 
-vector<string> PopulateTestInputs::TokenWord(const Token &token) {
+std::vector<string> PopulateTestInputs::TokenWord(const Token &token) {
   if (token.has_word()) return {token.word()};
   return {};
 }
......
@@ -130,17 +130,17 @@ class PopulateTestInputs {
   // then saved to FLAGS_test_tmpdir/name.
   static Create CreateTFMapFromDocumentTokens(
       const Sentence &document,
-      std::function<vector<string>(const Token &)> token2str);
+      std::function<std::vector<string>(const Token &)> token2str);
 
   // Creates a StringToStringMap protocol buffer input that maps tags to
   // categories. Uses whatever mapping is present in the document.
   static Create CreateTagToCategoryFromTokens(const Sentence &document);
 
   // Default implementations for "token2str" above.
-  static vector<string> TokenCategory(const Token &token);
-  static vector<string> TokenLabel(const Token &token);
-  static vector<string> TokenTag(const Token &token);
-  static vector<string> TokenWord(const Token &token);
+  static std::vector<string> TokenCategory(const Token &token);
+  static std::vector<string> TokenLabel(const Token &token);
+  static std::vector<string> TokenTag(const Token &token);
+  static std::vector<string> TokenWord(const Token &token);
 
   // Utility function. Sets the TaskInput->part() fields for a new input part.
   // Returns the file name.
......
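A hedged usage sketch of the declarations above: building a dummy word-frequency map input for a test with the TokenWord default. The sentence and input variables are assumed to be a Sentence proto and a TaskInput pointer provided by the surrounding test fixture; they are not part of this commit.

// Sketch: create the populator, then invoke it on a TaskInput. Per the
// comment above, the map is written under FLAGS_test_tmpdir.
PopulateTestInputs::Create create_word_map =
    PopulateTestInputs::CreateTFMapFromDocumentTokens(
        sentence, PopulateTestInputs::TokenWord);
create_word_map(input);  // writes the term-frequency map and fills the input parts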
@@ -131,7 +131,7 @@ class StdIn : public tensorflow::RandomAccessFile {
                      char *scratch) const {
     memcpy(scratch, buffer_.data(), buffer_.size());
     buffer_ = buffer_.substr(n);
-    result->set(scratch, n);
+    *result = tensorflow::StringPiece(scratch, n);
     expected_offset_ += n;
   }
@@ -161,7 +161,7 @@ class TextReader {
   Sentence *Read() {
     // Skips empty sentences, e.g., blank lines at the beginning of a file or
     // commented out blocks.
-    vector<Sentence *> sentences;
+    std::vector<Sentence *> sentences;
     string key, value;
     while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
       key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
......
@@ -143,7 +143,7 @@ class ParsingReader : public OpKernel {
     }
 
     // Create the outputs for each feature space.
-    vector<Tensor *> feature_outputs(features_->NumEmbeddings());
+    std::vector<Tensor *> feature_outputs(features_->NumEmbeddings());
     for (size_t i = 0; i < feature_outputs.size(); ++i) {
       OP_REQUIRES_OK(context, context->allocate_output(
                                   i, TensorShape({sentence_batch_->size(),
@@ -399,7 +399,7 @@ class DecodedParseReader : public ParsingReader {
     // pull from the back of the docids queue as long as the sentences have been
     // completely processed. If the next document has not been completely
     // processed yet, then the docid will not be found in 'sentence_map_'.
-    vector<Sentence> sentences;
+    std::vector<Sentence> sentences;
     while (!docids_.empty() &&
            sentence_map_.find(docids_.back()) != sentence_map_.end()) {
       sentences.emplace_back(sentence_map_[docids_.back()]);
@@ -427,7 +427,7 @@ class DecodedParseReader : public ParsingReader {
   string scoring_type_;
 
   mutable std::deque<string> docids_;
-  mutable map<string, Sentence> sentence_map_;
+  mutable std::map<string, Sentence> sentence_map_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
 };
......
@@ -28,11 +28,11 @@ limitations under the License.
 //   };
 //
 //   #define REGISTER_FUNCTION(type, component)
-//     REGISTER_INSTANCE_COMPONENT(Function, type, component);
+//     REGISTER_SYNTAXNET_INSTANCE_COMPONENT(Function, type, component);
 //
 //   function.cc:
 //
-//   REGISTER_INSTANCE_REGISTRY("function", Function);
+//   REGISTER_SYNTAXNET_INSTANCE_REGISTRY("function", Function);
 //
 //   class Cos : public Function {
 //    public:
@@ -218,22 +218,22 @@ class RegisterableInstance {
   static Registry registry_;
 };
 
-#define REGISTER_CLASS_COMPONENT(base, type, component) \
+#define REGISTER_SYNTAXNET_CLASS_COMPONENT(base, type, component)   \
   static base *__##component##__factory() { return new component; } \
   static base::Registry::Registrar __##component##__##registrar(    \
       base::registry(), type, #component, __FILE__, __LINE__,       \
       __##component##__factory)
 
-#define REGISTER_CLASS_REGISTRY(type, classname) \
+#define REGISTER_SYNTAXNET_CLASS_REGISTRY(type, classname)        \
   template <>                                                     \
   classname::Registry RegisterableClass<classname>::registry_ = { \
       type, #classname, __FILE__, __LINE__, NULL}
 
-#define REGISTER_INSTANCE_COMPONENT(base, type, component)       \
-  static base::Registry::Registrar __##component##__##registrar( \
+#define REGISTER_SYNTAXNET_INSTANCE_COMPONENT(base, type, component) \
+  static base::Registry::Registrar __##component##__##registrar(     \
       base::registry(), type, #component, __FILE__, __LINE__, new component)
 
-#define REGISTER_INSTANCE_REGISTRY(type, classname) \
+#define REGISTER_SYNTAXNET_INSTANCE_REGISTRY(type, classname)        \
   template <>                                                        \
   classname::Registry RegisterableInstance<classname>::registry_ = { \
       type, #classname, __FILE__, __LINE__, NULL}
......
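Following the Function example quoted in the header comment above, a registry and a component would now be declared with the renamed macros roughly as below. The Sin class and its Evaluate signature are illustrative assumptions; only the macro names and the Function/Cos outline come from this diff.

// function.h (sketch):
#define REGISTER_FUNCTION(type, component) \
  REGISTER_SYNTAXNET_INSTANCE_COMPONENT(Function, type, component)

// function.cc (sketch):
REGISTER_SYNTAXNET_INSTANCE_REGISTRY("function", Function);

class Sin : public Function {
 public:
  // Assumed interface; the Function base class body is not shown in this diff.
  double Evaluate(double x) override { return std::sin(x); }
};
REGISTER_FUNCTION("sin", Sin);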
@@ -38,7 +38,7 @@ const std::unordered_set<int> SegmenterUtils::kBreakChars({
 });
 
 void SegmenterUtils::GetUTF8Chars(const string &text,
-                                  vector<tensorflow::StringPiece> *chars) {
+                                  std::vector<tensorflow::StringPiece> *chars) {
   const char *start = text.c_str();
   const char *end = text.c_str() + text.size();
   while (start < end) {
@@ -50,7 +50,7 @@ void SegmenterUtils::GetUTF8Chars(const string &text,
 void SegmenterUtils::SetCharsAsTokens(
     const string &text,
-    const vector<tensorflow::StringPiece> &chars,
+    const std::vector<tensorflow::StringPiece> &chars,
     Sentence *sentence) {
   sentence->clear_token();
   sentence->set_text(text);
......
@@ -32,15 +32,15 @@ class SegmenterUtils {
   // Takes a text and convert it into a vector, where each element is a utf8
   // character.
   static void GetUTF8Chars(const string &text,
-                           vector<tensorflow::StringPiece> *chars);
+                           std::vector<tensorflow::StringPiece> *chars);
 
   // Sets tokens in the sentence so that each token is a single character.
   // Assigns the start/end byte offsets.
   //
   // If the sentence is not empty, the current tokens will be cleared.
-  static void SetCharsAsTokens(const string &text,
-                               const vector<tensorflow::StringPiece> &chars,
-                               Sentence *sentence);
+  static void SetCharsAsTokens(
+      const string &text, const std::vector<tensorflow::StringPiece> &chars,
+      Sentence *sentence);
 
   // Returns true for UTF-8 characters that cannot be 'real' tokens. This is
   // defined as any whitespace, line break or paragraph break.
......
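A short usage sketch of the two helpers with their updated signatures, combining them the way the tests below do. The text value is illustrative.

// Sketch: split a UTF-8 string into per-character StringPieces, then make
// each character its own token on a fresh Sentence.
const string text = "hello";
std::vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(text, &chars);   // one StringPiece per UTF-8 char

Sentence sentence;
SegmenterUtils::SetCharsAsTokens(text, chars, &sentence);  // one token per char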
@@ -59,9 +59,9 @@ static Sentence GetKoSentence() {
 // Gets the start end bytes of the given chars in the given text.
 static void GetStartEndBytes(const string &text,
-                             const vector<tensorflow::StringPiece> &chars,
-                             vector<int> *starts,
-                             vector<int> *ends) {
+                             const std::vector<tensorflow::StringPiece> &chars,
+                             std::vector<int> *starts,
+                             std::vector<int> *ends) {
   SegmenterUtils segment_utils;
   for (const tensorflow::StringPiece &c : chars) {
     int start; int end;
@@ -75,14 +75,14 @@ static void GetStartEndBytes(const string &text,
 TEST(SegmenterUtilsTest, GetCharsTest) {
   // Create test sentence.
   const Sentence sentence = GetKoSentence();
-  vector<tensorflow::StringPiece> chars;
+  std::vector<tensorflow::StringPiece> chars;
   SegmenterUtils::GetUTF8Chars(sentence.text(), &chars);
 
   // Check the number of characters is correct.
   CHECK_EQ(chars.size(), 12);
 
-  vector<int> starts;
-  vector<int> ends;
+  std::vector<int> starts;
+  std::vector<int> ends;
   GetStartEndBytes(sentence.text(), chars, &starts, &ends);
 
   // Check start positions.
@@ -118,12 +118,12 @@ TEST(SegmenterUtilsTest, GetCharsTest) {
 TEST(SegmenterUtilsTest, SetCharsAsTokensTest) {
   // Create test sentence.
   const Sentence sentence = GetKoSentence();
-  vector<tensorflow::StringPiece> chars;
+  std::vector<tensorflow::StringPiece> chars;
   SegmenterUtils segment_utils;
   segment_utils.GetUTF8Chars(sentence.text(), &chars);
 
-  vector<int> starts;
-  vector<int> ends;
+  std::vector<int> starts;
+  std::vector<int> ends;
   GetStartEndBytes(sentence.text(), chars, &starts, &ends);
 
   // Check that the new docs word, start and end positions are properly set.
......
@@ -83,7 +83,8 @@ string TermFrequencyMapSetFeature::WorkspaceName() const {
 }
 
 namespace {
-void GetUTF8Chars(const string &word, vector<tensorflow::StringPiece> *chars) {
+void GetUTF8Chars(const string &word,
+                  std::vector<tensorflow::StringPiece> *chars) {
   UnicodeText text;
   text.PointToUTF8(word.c_str(), word.size());
   for (UnicodeText::const_iterator it = text.begin(); it != text.end(); ++it) {
@@ -98,9 +99,10 @@ int UTF8FirstLetterNumBytes(const char *utf8_str) {
 }  // namespace
 
-void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const {
+void CharNgram::GetTokenIndices(const Token &token,
+                                std::vector<int> *values) const {
   values->clear();
-  vector<tensorflow::StringPiece> char_sp;
+  std::vector<tensorflow::StringPiece> char_sp;
   if (use_terminators_) char_sp.push_back("^");
   GetUTF8Chars(token.word(), &char_sp);
   if (use_terminators_) char_sp.push_back("$");
@@ -121,7 +123,7 @@ void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const {
 }
 
 void MorphologySet::GetTokenIndices(const Token &token,
-                                    vector<int> *values) const {
+                                    std::vector<int> *values) const {
   values->clear();
   const TokenMorphology &token_morphology =
       token.GetExtension(TokenMorphology::morphology);
@@ -401,7 +403,8 @@ string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
 }
 
 // Registry for the Sentence + token index feature functions.
-REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);
+REGISTER_SYNTAXNET_CLASS_REGISTRY("sentence+index feature function",
+                                  SentenceFeature);
 
 // Register the features defined in the header.
 REGISTER_SENTENCE_IDX_FEATURE("word", Word);
......
@@ -110,7 +110,7 @@ class TokenLookupSetFeature : public SentenceFeature {
   // feature value set. The index is relative to the start of the sentence.
   virtual void LookupToken(const WorkspaceSet &workspaces,
                            const Sentence &sentence, int index,
-                           vector<int> *values) const = 0;
+                           std::vector<int> *values) const = 0;
 
   // Given a feature value, returns a string representation.
   virtual string GetFeatureValueName(int value) const = 0;
@@ -137,7 +137,7 @@ class TokenLookupSetFeature : public SentenceFeature {
   // Returns a pre-computed token value from the cache. This assumes the cache
   // is populated.
-  const vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces,
+  const std::vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces,
                                        const Sentence &sentence,
                                        int focus) const {
     // Do bounds checking on focus.
@@ -152,7 +152,7 @@ class TokenLookupSetFeature : public SentenceFeature {
   void Evaluate(const WorkspaceSet &workspaces, const Sentence &sentence,
                 int focus, FeatureVector *result) const override {
     if (focus >= 0 && focus < sentence.token_size()) {
-      const vector<int> &elements =
+      const std::vector<int> &elements =
           GetCachedValueSet(workspaces, sentence, focus);
       for (auto &value : elements) {
         result->add(this->feature_type(), value);
@@ -233,7 +233,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
   // Returns index of raw word text.
   virtual void GetTokenIndices(const Token &token,
-                               vector<int> *values) const = 0;
+                               std::vector<int> *values) const = 0;
 
   // Requests the resource inputs.
   void Setup(TaskContext *context) override;
@@ -261,7 +261,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
   // Given a position in a sentence and workspaces, looks up the corresponding
   // feature value set. The index is relative to the start of the sentence.
   void LookupToken(const WorkspaceSet &workspaces, const Sentence &sentence,
-                   int index, vector<int> *values) const override {
+                   int index, std::vector<int> *values) const override {
     GetTokenIndices(sentence.token(index), values);
   }
@@ -376,7 +376,8 @@ class CharNgram : public TermFrequencyMapSetFeature {
   }
 
   // Returns index of raw word text.
-  void GetTokenIndices(const Token &token, vector<int> *values) const override;
+  void GetTokenIndices(const Token &token,
+                       std::vector<int> *values) const override;
 
  private:
   // Size parameter (n) for the ngrams.
@@ -401,7 +402,8 @@ class MorphologySet : public TermFrequencyMapSetFeature {
   }
 
   // Returns index of raw word text.
-  void GetTokenIndices(const Token &token, vector<int> *values) const override;
+  void GetTokenIndices(const Token &token,
+                       std::vector<int> *values) const override;
 };
 
 class LexicalCategoryFeature : public TokenLookupFeature {
@@ -635,7 +637,7 @@ typedef FeatureExtractor<Sentence, int> SentenceExtractor;
 // Utility to register the sentence_instance::Feature functions.
 #define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
-  REGISTER_FEATURE_FUNCTION(SentenceFeature, name, type)
+  REGISTER_SYNTAXNET_FEATURE_FUNCTION(SentenceFeature, name, type)
 
 }  // namespace syntaxnet
......